diff --git a/whisperlivekit/TokensAlignment.py b/whisperlivekit/TokensAlignment.py index f2b0218..2cb2815 100644 --- a/whisperlivekit/TokensAlignment.py +++ b/whisperlivekit/TokensAlignment.py @@ -1,6 +1,10 @@ -from whisperlivekit.timed_objects import Line, SilentLine, format_time, SpeakerSegment, Silence -from whisperlivekit.timed_objects import PunctuationSegment from time import time +from typing import Optional + +from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence +from whisperlivekit.timed_objects import PunctuationSegment + +ALIGNMENT_TIME_TOLERANCE = 0.2 # seconds class TokensAlignment: @@ -12,15 +16,16 @@ class TokensAlignment: self._diarization_index = 0 self._translation_index = 0 - self.all_tokens = [] - self.all_diarization_segments = [] + self.all_tokens : list[ASRToken] = [] + self.all_diarization_segments: list[SpeakerSegment] = [] self.all_translation_segments = [] - self.new_tokens = [] + self.new_tokens : list[ASRToken] = [] + self.new_diarization: list[SpeakerSegment] = [] self.new_translation = [] - self.new_diarization = [] self.new_tokens_buffer = [] - self.sep = ' ' + self.sep = sep if sep is not None else ' ' + self.beg_loop = None def update(self): self.new_tokens, self.state.new_tokens = self.state.new_tokens, [] @@ -32,7 +37,10 @@ class TokensAlignment: self.all_diarization_segments.extend(self.new_diarization) self.all_translation_segments.extend(self.new_translation) - def create_lines_from_tokens(self, current_silence, beg_loop): + def get_lines(self, current_silence): + """ + In the case without diarization + """ lines = [] current_line_tokens = [] for token in self.all_tokens: @@ -40,7 +48,7 @@ class TokensAlignment: if current_line_tokens: lines.append(Line().build_from_tokens(current_line_tokens)) current_line_tokens = [] - end_silence = token.end if token.has_ended else time() - beg_loop + end_silence = token.end if token.has_ended else time() - self.beg_loop if lines and 
lines[-1].is_silent(): lines[-1].end = end_silence else: @@ -53,7 +61,7 @@ class TokensAlignment: if current_line_tokens: lines.append(Line().build_from_tokens(current_line_tokens)) if current_silence: - end_silence = current_silence.end if current_silence.has_ended else time() - beg_loop + end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop if lines and lines[-1].is_silent(): lines[-1].end = end_silence else: @@ -64,22 +72,104 @@ class TokensAlignment: return lines - def align_tokens(self): - if not self.diarization: - pass - # return self.all_tokens - def compute_punctuations_segments(self): + def _get_asr_tokens(self) -> list[ASRToken]: + return [token for token in self.all_tokens if isinstance(token, ASRToken)] + + def _tokens_to_text(self, tokens: list[ASRToken]) -> str: + return ''.join(token.text for token in tokens) + + def _extract_detected_language(self, tokens: list[ASRToken]): + for token in tokens: + if getattr(token, 'detected_language', None): + return token.detected_language + return None + + def _speaker_display_id(self, raw_speaker) -> int: + if isinstance(raw_speaker, int): + speaker_index = raw_speaker + else: + digits = ''.join(ch for ch in str(raw_speaker) if ch.isdigit()) + speaker_index = int(digits) if digits else 0 + return speaker_index + 1 if speaker_index >= 0 else 0 + + def _line_from_tokens(self, tokens: list[ASRToken], speaker: int) -> Line: + line = Line().build_from_tokens(tokens) + line.speaker = speaker + detected_language = self._extract_detected_language(tokens) + if detected_language: + line.detected_language = detected_language + return line + + def _find_initial_diar_index(self, diar_segments: list[SpeakerSegment], start_time: float) -> int: + for idx, segment in enumerate(diar_segments): + if segment.end + ALIGNMENT_TIME_TOLERANCE >= start_time: + return idx + return len(diar_segments) + + def _find_speaker_for_token(self, token: ASRToken, diar_segments: list[SpeakerSegment], diar_idx: 
int): + if not diar_segments: + return None, diar_idx + idx = min(diar_idx, len(diar_segments) - 1) + midpoint = (token.start + token.end) / 2 if token.end is not None else token.start + + while idx < len(diar_segments) and diar_segments[idx].end + ALIGNMENT_TIME_TOLERANCE < midpoint: + idx += 1 + + candidate_indices = [] + if idx < len(diar_segments): + candidate_indices.append(idx) + if idx > 0: + candidate_indices.append(idx - 1) + + for candidate_idx in candidate_indices: + segment = diar_segments[candidate_idx] + seg_start = (segment.start or 0) - ALIGNMENT_TIME_TOLERANCE + seg_end = (segment.end or 0) + ALIGNMENT_TIME_TOLERANCE + if seg_start <= midpoint <= seg_end: + return segment.speaker, candidate_idx + + return None, idx + + def _build_lines_for_tokens(self, tokens: list[ASRToken], diar_segments: list[SpeakerSegment], diar_idx: int): + if not tokens: + return [], diar_idx + + segment_lines: list[Line] = [] + current_tokens: list[ASRToken] = [] + current_speaker = None + pointer = diar_idx + + for token in tokens: + speaker_raw, pointer = self._find_speaker_for_token(token, diar_segments, pointer) + if speaker_raw is None: + return [], diar_idx + speaker = self._speaker_display_id(speaker_raw) + if current_speaker is None or current_speaker != speaker: + if current_tokens: + segment_lines.append(self._line_from_tokens(current_tokens, current_speaker)) + current_tokens = [token] + current_speaker = speaker + else: + current_tokens.append(token) + + if current_tokens: + segment_lines.append(self._line_from_tokens(current_tokens, current_speaker)) + + return segment_lines, pointer + + def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None): """Compute segments of text between punctuation marks. Returns a list of PunctuationSegment objects, each representing the text from the start (or previous punctuation) to the current punctuation mark. 
""" - if not self.all_tokens: + tokens = tokens if tokens is not None else self._get_asr_tokens() + if not tokens: return [] punctuation_indices = [ - i for i, token in enumerate(self.all_tokens) + i for i, token in enumerate(tokens) if token.is_punctuation() ] if not punctuation_indices: @@ -91,7 +181,7 @@ class TokensAlignment: end_idx = punct_idx if start_idx <= end_idx: segment = PunctuationSegment.from_token_range( - tokens=self.all_tokens, + tokens=tokens, token_index_start=start_idx, token_index_end=end_idx, punctuation_token_index=punct_idx @@ -109,4 +199,42 @@ class TokensAlignment: merged[-1].end = segment.end else: merged.append(segment) - return merged \ No newline at end of file + return merged + + def get_lines(self, diarization=False, translation=False): + """ + Align diarization speaker segments with punctuation-delimited transcription + segments (see docs/alignement_principles.md). + """ + tokens = self._get_asr_tokens() + if not tokens: + return [], '' + + punctuation_segments = self.compute_punctuations_segments(tokens=tokens) + diar_segments = self.concatenate_diar_segments() + + if not punctuation_segments or not diar_segments: + return [], self._tokens_to_text(tokens) + + max_diar_end = diar_segments[-1].end + if max_diar_end is None: + return [], self._tokens_to_text(tokens) + + lines: list[Line] = [] + last_consumed_index = -1 + diar_idx = self._find_initial_diar_index(diar_segments, tokens[0].start or 0) + + for segment in punctuation_segments: + if segment.end is None or segment.end > max_diar_end: + break + slice_tokens = tokens[segment.token_index_start:segment.token_index_end + 1] + segment_lines, diar_idx = self._build_lines_for_tokens(slice_tokens, diar_segments, diar_idx) + if not segment_lines: + break + lines.extend(segment_lines) + last_consumed_index = segment.token_index_end + + buffer_tokens = tokens[last_consumed_index + 1:] if last_consumed_index + 1 < len(tokens) else [] + buffer_diarization = 
self._tokens_to_text(buffer_tokens) + + return lines, buffer_diarization \ No newline at end of file diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 8263ef3..cd7b9f9 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) SENTINEL = object() # unique sentinel object for end of stream marker -MILENCE_DURATION = 3 +MIN_DURATION_REAL_SILENCE = 5 def cut_at(cumulative_pcm, cut_sec): cumulative_len = 0 @@ -165,7 +165,7 @@ class AudioProcessor: self.current_silence.is_starting=False self.current_silence.has_ended=True self.current_silence.compute_duration() - if self.current_silence.duration > MILENCE_DURATION: + if self.current_silence.duration > MIN_DURATION_REAL_SILENCE: self.state_light.new_tokens.append(self.current_silence) await self._push_silence_event() self.current_silence = None @@ -410,57 +410,32 @@ class AudioProcessor: continue self.tokens_alignment.update() - lines = self.tokens_alignment.create_lines_from_tokens(self.current_silence, self.beg_loop) - undiarized_text = '' + lines, buffer_diarization_text = self.tokens_alignment.get_lines( + diarization=self.args.diarization, + translation=self.args.translation + ) state = await self.get_current_state() - # self.tokens_alignment.compute_punctuations_segments() - # lines, undiarized_text = format_output( - # state, - # self.current_silence, - # args = self.args, - # sep=self.sep - # ) - if lines and lines[-1].speaker == -2: - buffer_transcription = Transcript() - else: - buffer_transcription = state.buffer_transcription - - buffer_diarization = '' - if undiarized_text: - buffer_diarization = self.sep.join(undiarized_text) - - async with self.lock: - self.state.end_attributed_speaker = state.end_attributed_speaker buffer_translation_text = '' - if state.buffer_translation: - raw_buffer_translation = 
getattr(state.buffer_translation, 'text', state.buffer_translation) - if raw_buffer_translation: - buffer_translation_text = raw_buffer_translation.strip() - + buffer_transcription_text = '' + # buffer_diarization_text comes from tokens_alignment.get_lines() above + response_status = "active_transcription" - if not state.tokens and not buffer_transcription and not buffer_diarization: + if not lines and not buffer_transcription_text and not buffer_diarization_text: response_status = "no_audio_detected" - lines = [] - elif not lines: - lines = [Line( - speaker=1, - start=state.end_buffer, - end=state.end_buffer - )] - + response = FrontData( status=response_status, lines=lines, - buffer_transcription=buffer_transcription.text.strip(), - buffer_diarization=buffer_diarization, + buffer_transcription=buffer_transcription_text, + buffer_diarization=buffer_diarization_text, buffer_translation=buffer_translation_text, remaining_time_transcription=state.remaining_time_transcription, remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 ) should_push = (response != self.last_response_content) - if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"): + if should_push: yield response self.last_response_content = response @@ -582,6 +557,7 @@ class AudioProcessor: if not self.beg_loop: self.beg_loop = time() self.current_silence = Silence(start=0.0, is_starting=True) + self.tokens_alignment.beg_loop = self.beg_loop if not message: logger.info("Empty audio message received, initiating stop sequence.") diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index c2310ec..1a04d4a 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -162,8 +162,10 @@ class Line(TimedText): return self.speaker == -2 class SilentLine(Line): - speaker = -2 - text = '' + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.speaker = -2 + self.text = '' @dataclass @@ 
-192,8 +194,10 @@ class FrontData(): return _dict @dataclass -class PunctuationSegment(TimedText): +class PunctuationSegment(): """Represents a segment of text between punctuation marks.""" + start: Optional[float] + end: Optional[float] token_index_start: int token_index_end: int punctuation_token_index: int