diff --git a/whisperlivekit/TokensAlignment.py b/whisperlivekit/TokensAlignment.py deleted file mode 100644 index 2cb2815..0000000 --- a/whisperlivekit/TokensAlignment.py +++ /dev/null @@ -1,240 +0,0 @@ -from time import time -from typing import Optional - -from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence -from whisperlivekit.timed_objects import PunctuationSegment - -ALIGNMENT_TIME_TOLERANCE = 0.2 # seconds - - -class TokensAlignment: - - def __init__(self, state, args, sep): - self.state = state - self.diarization = args.diarization - self._tokens_index = 0 - self._diarization_index = 0 - self._translation_index = 0 - - self.all_tokens : list[ASRToken] = [] - self.all_diarization_segments: list[SpeakerSegment] = [] - self.all_translation_segments = [] - - self.new_tokens : list[ASRToken] = [] - self.new_diarization: list[SpeakerSegment] = [] - self.new_translation = [] - self.new_tokens_buffer = [] - self.sep = sep if sep is not None else ' ' - self.beg_loop = None - - def update(self): - self.new_tokens, self.state.new_tokens = self.state.new_tokens, [] - self.new_diarization, self.state.new_diarization = self.state.new_diarization, [] - self.new_translation, self.state.new_translation = self.state.new_translation, [] - self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, [] - - self.all_tokens.extend(self.new_tokens) - self.all_diarization_segments.extend(self.new_diarization) - self.all_translation_segments.extend(self.new_translation) - - def get_lines(self, current_silence): - """ - In the case without diarization - """ - lines = [] - current_line_tokens = [] - for token in self.all_tokens: - if type(token) == Silence: - if current_line_tokens: - lines.append(Line().build_from_tokens(current_line_tokens)) - current_line_tokens = [] - end_silence = token.end if token.has_ended else time() - self.beg_loop - if lines and lines[-1].is_silent(): - lines[-1].end = end_silence - else: - lines.append(SilentLine( - start = token.start, - end = end_silence - )) - else: - current_line_tokens.append(token) - if current_line_tokens: - lines.append(Line().build_from_tokens(current_line_tokens)) - if current_silence: - end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop - if lines and lines[-1].is_silent(): - lines[-1].end = end_silence - else: - lines.append(SilentLine( - start = current_silence.start, - end = end_silence - )) - - return lines - - - def _get_asr_tokens(self) -> list[ASRToken]: - return [token for token in self.all_tokens if isinstance(token, ASRToken)] - - def _tokens_to_text(self, tokens: list[ASRToken]) -> str: - return ''.join(token.text for token in tokens) - - def _extract_detected_language(self, tokens: list[ASRToken]): - for token in tokens: - if getattr(token, 'detected_language', None): - return token.detected_language - return None - - def _speaker_display_id(self, raw_speaker) -> int: - if isinstance(raw_speaker, int): - speaker_index = raw_speaker - else: - digits = ''.join(ch for ch in str(raw_speaker) if ch.isdigit()) - speaker_index = int(digits) if digits else 0 - return speaker_index + 1 if speaker_index >= 0 else 0 - - def _line_from_tokens(self, tokens: list[ASRToken], speaker: int) -> Line: - line = Line().build_from_tokens(tokens) - line.speaker = speaker - detected_language = self._extract_detected_language(tokens) - if detected_language: - line.detected_language = detected_language - return line - - def _find_initial_diar_index(self, diar_segments: list[SpeakerSegment], start_time: float) -> int: - for idx, segment in enumerate(diar_segments): - if segment.end + ALIGNMENT_TIME_TOLERANCE >= start_time: - return idx - return len(diar_segments) - - def _find_speaker_for_token(self, token: ASRToken, diar_segments: list[SpeakerSegment], diar_idx: int): - if not diar_segments: - return None, diar_idx - idx = min(diar_idx, len(diar_segments) - 1) - midpoint = (token.start + token.end) / 2 if token.end is not None else token.start - - while idx < len(diar_segments) and diar_segments[idx].end + ALIGNMENT_TIME_TOLERANCE < midpoint: - idx += 1 - - candidate_indices = [] - if idx < len(diar_segments): - candidate_indices.append(idx) - if idx > 0: - candidate_indices.append(idx - 1) - - for candidate_idx in candidate_indices: - segment = diar_segments[candidate_idx] - seg_start = (segment.start or 0) - ALIGNMENT_TIME_TOLERANCE - seg_end = (segment.end or 0) + ALIGNMENT_TIME_TOLERANCE - if seg_start <= midpoint <= seg_end: - return segment.speaker, candidate_idx - - return None, idx - - def _build_lines_for_tokens(self, tokens: list[ASRToken], diar_segments: list[SpeakerSegment], diar_idx: int): - if not tokens: - return [], diar_idx - - segment_lines: list[Line] = [] - current_tokens: list[ASRToken] = [] - current_speaker = None - pointer = diar_idx - - for token in tokens: - speaker_raw, pointer = self._find_speaker_for_token(token, diar_segments, pointer) - if speaker_raw is None: - return [], diar_idx - speaker = self._speaker_display_id(speaker_raw) - if current_speaker is None or current_speaker != speaker: - if current_tokens: - segment_lines.append(self._line_from_tokens(current_tokens, current_speaker)) - current_tokens = [token] - current_speaker = speaker - else: - current_tokens.append(token) - - if current_tokens: - segment_lines.append(self._line_from_tokens(current_tokens, current_speaker)) - - return segment_lines, pointer - - def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None): - """Compute segments of text between punctuation marks. - - Returns a list of PunctuationSegment objects, each representing - the text from the start (or previous punctuation) to the current punctuation mark. - """ - - tokens = tokens if tokens is not None else self._get_asr_tokens() - if not tokens: - return [] - punctuation_indices = [ - i for i, token in enumerate[ASRToken](tokens) - if token.is_punctuation() - ] - if not punctuation_indices: - return [] - - segments = [] - for i, punct_idx in enumerate(punctuation_indices): - start_idx = punctuation_indices[i - 1] + 1 if i > 0 else 0 - end_idx = punct_idx - if start_idx <= end_idx: - segment = PunctuationSegment.from_token_range( - tokens=tokens, - token_index_start=start_idx, - token_index_end=end_idx, - punctuation_token_index=punct_idx - ) - segments.append(segment) - return segments - - - def concatenate_diar_segments(self): - if not self.all_diarization_segments: - return [] - merged = [self.all_diarization_segments[0]] - for segment in self.all_diarization_segments[1:]: - if segment.speaker == merged[-1].speaker: - merged[-1].end = segment.end - else: - merged.append(segment) - return merged - - def get_lines(self, diarization=False, translation=False): - """ - Align diarization speaker segments with punctuation-delimited transcription - segments (see docs/alignement_principles.md). - """ - tokens = self._get_asr_tokens() - if not tokens: - return [], '' - - punctuation_segments = self.compute_punctuations_segments(tokens=tokens) - diar_segments = self.concatenate_diar_segments() - - if not punctuation_segments or not diar_segments: - return [], self._tokens_to_text(tokens) - - max_diar_end = diar_segments[-1].end - if max_diar_end is None: - return [], self._tokens_to_text(tokens) - - lines: list[Line] = [] - last_consumed_index = -1 - diar_idx = self._find_initial_diar_index(diar_segments, tokens[0].start or 0) - - for segment in punctuation_segments: - if segment.end is None or segment.end > max_diar_end: - break - slice_tokens = tokens[segment.token_index_start:segment.token_index_end + 1] - segment_lines, diar_idx = self._build_lines_for_tokens(slice_tokens, diar_segments, diar_idx) - if not segment_lines: - break - lines.extend(segment_lines) - last_consumed_index = segment.token_index_end - - buffer_tokens = tokens[last_consumed_index + 1:] if last_consumed_index + 1 < len(tokens) else [] - buffer_diarization = self._tokens_to_text(buffer_tokens) - - return lines, buffer_diarization \ No newline at end of file diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index cd7b9f9..e6d092c 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -7,9 +7,8 @@ import traceback from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory from whisperlivekit.silero_vad_iterator import FixedVADIterator -from whisperlivekit.results_formater import format_output from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState -from whisperlivekit.TokensAlignment import TokensAlignment +from whisperlivekit.tokens_alignment import TokensAlignment logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -392,8 +391,8 @@ class AudioProcessor: self.translation.insert_tokens(tokens_to_process) translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process) async with self.lock: - self.state.translation_validated_segments = translation_validated_segments - self.state.buffer_translation = buffer_translation + self.state_light.new_translation = translation_validated_segments + self.state_light.new_translation_buffer = buffer_translation except Exception as e: logger.warning(f"Exception in translation_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") @@ -412,11 +411,11 @@ class AudioProcessor: self.tokens_alignment.update() lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines( diarization=self.args.diarization, - translation=self.args.translation + translation=bool(self.translation), + current_silence=self.current_silence ) state = await self.get_current_state() - buffer_translation_text = '' buffer_transcription_text = '' buffer_diarization_text = '' diff --git a/whisperlivekit/remove_silences.py b/whisperlivekit/remove_silences.py deleted file mode 100644 index 368d34c..0000000 --- a/whisperlivekit/remove_silences.py +++ /dev/null @@ -1,103 +0,0 @@ -from whisperlivekit.timed_objects import ASRToken -from time import time -import re - -MIN_SILENCE_DURATION = 4 #in seconds -END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important -END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences - -def blank_to_silence(tokens): - full_string = ''.join([t.text for t in tokens]) - patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')] - matches = [] - for pattern in patterns: - for m in pattern.finditer(full_string): - matches.append({ - 'start': m.start(), - 'end': m.end() - }) - if matches: - # cleaned = pattern.sub(' ', full_string).strip() - # print("Cleaned:", cleaned) - cumulated_len = 0 - silence_token = None - cleaned_tokens = [] - for token in tokens: - if matches: - start = cumulated_len - end = cumulated_len + len(token.text) - cumulated_len = end - if start >= matches[0]['start'] and end <= matches[0]['end']: - if silence_token: #previous token was already silence - silence_token.start = min(silence_token.start, token.start) - silence_token.end = max(silence_token.end, token.end) - else: #new silence - silence_token = ASRToken( - start=token.start, - end=token.end, - speaker=-2, - ) - else: - if silence_token: #there was silence but no more - if silence_token.duration() >= MIN_SILENCE_DURATION: - cleaned_tokens.append( - silence_token - ) - silence_token = None - matches.pop(0) - cleaned_tokens.append(token) - # print(cleaned_tokens) - return cleaned_tokens - return tokens - -def no_token_to_silence(tokens): - new_tokens = [] - silence_token = None - for token in tokens: - if token.speaker == -2: - if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too - new_tokens[-1].end = token.end - else: - new_tokens.append(token) - - last_end = new_tokens[-1].end if new_tokens else 0.0 - if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap - if new_tokens and new_tokens[-1].speaker == -2: - new_tokens[-1].end = token.start - else: - silence_token = ASRToken( - start=last_end, - end=token.start, - speaker=-2, - ) - new_tokens.append(silence_token) - - if token.speaker != -2: - new_tokens.append(token) - return new_tokens - -def ends_with_silence(tokens, beg_loop, vac_detected_silence): - current_time = time() - (beg_loop if beg_loop else 0.0) - last_token = tokens[-1] - if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION): - if last_token.speaker == -2: - last_token.end = current_time - else: - tokens.append( - ASRToken( - start=tokens[-1].end, - end=current_time, - speaker=-2, - ) - ) - return tokens - - -def handle_silences(tokens, beg_loop, vac_detected_silence): - if not tokens: - return [] - tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text - tokens = no_token_to_silence(tokens) - tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence) - return tokens - \ No newline at end of file diff --git a/whisperlivekit/result_diarization.md b/whisperlivekit/result_diarization.md deleted file mode 100644 index 78607b3..0000000 --- a/whisperlivekit/result_diarization.md +++ /dev/null @@ -1,60 +0,0 @@ -########### WHAT IS PRODUCED: ########### - -SPEAKER 1 0:00:04 - 0:00:06 -Transcription technology has improved so much in the past - -SPEAKER 1 0:00:07 - 0:00:12 -years. Have you noticed how accurate real-time speech detects is now? - -SPEAKER 2 0:00:12 - 0:00:12 -Absolutely - -SPEAKER 1 0:00:13 - 0:00:13 -. - -SPEAKER 2 0:00:14 - 0:00:14 -I - -SPEAKER 1 0:00:14 - 0:00:17 -use it all the time for taking notes during meetings. - -SPEAKER 2 0:00:17 - 0:00:17 -It - -SPEAKER 1 0:00:17 - 0:00:22 -'s amazing how it can recognize different speakers, and even add punctuation. - -SPEAKER 2 0:00:22 - 0:00:22 -Yeah - -SPEAKER 1 0:00:23 - 0:00:26 -, but sometimes noise can still cause mistakes. - -SPEAKER 3 0:00:26 - 0:00:27 -Does - -SPEAKER 1 0:00:27 - 0:00:28 -this system handle that - -SPEAKER 1 0:00:29 - 0:00:29 -? - -SPEAKER 3 0:00:29 - 0:00:29 -It - -SPEAKER 1 0:00:29 - 0:00:33 -does a pretty good job filtering noise, especially with models that use voice activity - -########### WHAT SHOULD BE PRODUCED: ########### - -SPEAKER 1 0:00:04 - 0:00:12 -Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now? - -SPEAKER 2 0:00:12 - 0:00:22 -Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation. - -SPEAKER 3 0:00:22 - 0:00:28 -Yeah, but sometimes noise can still cause mistakes. Does this system handle that well? - -SPEAKER 1 0:00:29 - 0:00:29 -It does a pretty good job filtering noise, especially with models that use voice activity \ No newline at end of file diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py deleted file mode 100644 index 07cbbb5..0000000 --- a/whisperlivekit/results_formater.py +++ /dev/null @@ -1,257 +0,0 @@ - -import logging -import re -from whisperlivekit.remove_silences import handle_silences -from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment -from typing import List - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -CHECK_AROUND = 4 -DEBUG = False - -def next_punctuation_change(i, tokens): - for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)): - if tokens[ind].is_punctuation(): - return ind - return None - -def next_speaker_change(i, tokens, speaker): - for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1): - token = tokens[ind] - if token.is_punctuation(): - break - if token.speaker != speaker: - return ind, token.speaker - return None, speaker - -def new_line( - token, -): - return Line( - speaker = token.corrected_speaker, - text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""), - start = token.start, - end = token.end, - detected_language=token.detected_language - ) - -def append_token_to_last_line(lines, sep, token): - if not lines: - lines.append(new_line(token)) - else: - if token.text: - lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "") - lines[-1].end = token.end - if not lines[-1].detected_language and token.detected_language: - lines[-1].detected_language = token.detected_language - -def extract_number(s) -> int: - """Extract number from speaker string (for diart compatibility).""" - if isinstance(s, int): - return s - m = re.search(r'\d+', str(s)) - return int(m.group()) if m else 0 - -def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]: - """Concatenate consecutive segments from the same speaker.""" - if not segments: - return [] - - # Get speaker number from first segment - first_speaker = extract_number(segments[0].speaker) - segments_concatenated = [{"speaker": first_speaker + 1, "begin": segments[0].start, "end": segments[0].end}] - - for segment in segments[1:]: - speaker = extract_number(segment.speaker) + 1 - if segments_concatenated[-1]['speaker'] != speaker: - segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end}) - else: - segments_concatenated[-1]['end'] = segment.end - - return segments_concatenated - -def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list: - """Assign speakers to tokens with punctuation-aware boundary adjustment.""" - punctuation_marks = {'.', '!', '?'} - punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks] - segments_concatenated = concatenate_speakers(segments) - - for ind, segment in enumerate(segments_concatenated): - for i, punctuation_token in enumerate(punctuation_tokens): - if punctuation_token.start > segment['end']: - after_length = punctuation_token.start - segment['end'] - before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf') - if before_length > after_length: - segment['end'] = punctuation_token.start - if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated): - segments_concatenated[ind + 1]['begin'] = punctuation_token.start - else: - segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end'] - if i < len(punctuation_tokens) - 1 and ind - 1 >= 0: - segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end - break - - # Ensure non-overlapping tokens - last_end = 0.0 - for token in tokens: - start = max(last_end + 0.01, token.start) - token.start = start - token.end = max(start, token.end) - last_end = token.end - - # Assign speakers based on adjusted segments - ind_last_speaker = 0 - for segment in segments_concatenated: - for i, token in enumerate(tokens[ind_last_speaker:]): - if token.end <= segment['end']: - token.speaker = segment['speaker'] - ind_last_speaker = i + 1 - elif token.start > segment['end']: - break - - return tokens - -def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list: - """ - Assign speakers to tokens based on timing overlap with speaker segments. - - Args: - tokens: List of tokens with timing information - segments: List of speaker segments - use_punctuation_split: Whether to use punctuation for boundary refinement - - Returns: - List of tokens with speaker assignments - """ - if not segments or not tokens: - logger.debug("No segments or tokens available for speaker assignment") - return tokens - - logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments") - - if not use_punctuation_split: - # Simple overlap-based assignment - for token in tokens: - token.speaker = -1 # Default to no speaker - for segment in segments: - # Check for timing overlap - if not (segment.end <= token.start or segment.start >= token.end): - speaker_num = extract_number(segment.speaker) - token.speaker = speaker_num + 1 # Convert to 1-based indexing - break - else: - # Use punctuation-aware assignment - tokens = add_speaker_to_tokens_with_punctuation(segments, tokens) - - return tokens - -def format_output(state, silence, args, sep): - diarization = args.diarization - disable_punctuation_split = args.disable_punctuation_split - tokens = state.tokens - translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments - last_validated_token = state.last_validated_token - - last_speaker = abs(state.last_speaker) - undiarized_text = [] - tokens = handle_silences(tokens, state.beg_loop, silence) - - # Assign speakers to tokens based on segments stored in state - if False and diarization and state.diarization_segments: - use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False - tokens = assign_speakers_to_tokens(tokens, state.diarization_segments, use_punctuation_split=use_punctuation_split) - for i in range(last_validated_token, len(tokens)): - token = tokens[i] - speaker = int(token.speaker) - token.corrected_speaker = speaker - if True or not diarization: - if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1' - token.corrected_speaker = 1 - token.validated_speaker = True - else: - if token.speaker == -1: - undiarized_text.append(token.text) - elif token.is_punctuation(): - state.last_punctuation_index = i - token.corrected_speaker = last_speaker - token.validated_speaker = True - elif state.last_punctuation_index == i-1: - if token.speaker != last_speaker: - token.corrected_speaker = token.speaker - token.validated_speaker = True - # perfect, diarization perfectly aligned - else: - speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker) - if speaker_change_pos: - # Corrects delay: - # That was the idea. haha |SPLIT SPEAKER| that's a good one - # should become: - # That was the idea. |SPLIT SPEAKER| haha that's a good one - token.corrected_speaker = new_speaker - token.validated_speaker = True - elif speaker != last_speaker: - if not (speaker == -2 or last_speaker == -2): - if next_punctuation_change(i, tokens): - # Corrects advance: - # Are you |SPLIT SPEAKER| ? yeah, sure. Absolutely - # should become: - # Are you ? |SPLIT SPEAKER| yeah, sure. Absolutely - token.corrected_speaker = last_speaker - token.validated_speaker = True - else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True. - if not disable_punctuation_split: - token.corrected_speaker = last_speaker - token.validated_speaker = False - if token.validated_speaker: - state.last_validated_token = i - state.last_speaker = token.corrected_speaker - - last_speaker = 1 - - lines = [] - for token in tokens: - if token.corrected_speaker != -1: - if int(token.corrected_speaker) != int(last_speaker): - lines.append(new_line(token)) - else: - append_token_to_last_line(lines, sep, token) - - last_speaker = token.corrected_speaker - - if lines: - unassigned_translated_segments = [] - for ts in translation_validated_segments: - assigned = False - for line in lines: - if ts and ts.overlaps_with(line): - if ts.is_within(line): - line.translation += ts.text + ' ' - assigned = True - break - else: - ts0, ts1 = ts.approximate_cut_at(line.end) - if ts0 and line.overlaps_with(ts0): - line.translation += ts0.text + ' ' - if ts1: - unassigned_translated_segments.append(ts1) - assigned = True - break - if not assigned: - unassigned_translated_segments.append(ts) - - if unassigned_translated_segments: - for line in lines: - remaining_segments = [] - for ts in unassigned_translated_segments: - if ts and ts.overlaps_with(line): - line.translation += ts.text + ' ' - else: - remaining_segments.append(ts) - unassigned_translated_segments = remaining_segments #maybe do smth in the future about that - - if state.buffer_transcription and lines: - lines[-1].end = max(state.buffer_transcription.end, lines[-1].end) - - return lines, undiarized_text diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 1a04d4a..e0b7da7 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import Optional, Any, List from datetime import timedelta +from typing import Union PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} @@ -39,7 +40,9 @@ class TimedText(Timed): def __bool__(self): return bool(self.text) - + + def __str__(self): + return str(self.text) @dataclass() class ASRToken(TimedText): @@ -53,6 +56,10 @@ class ASRToken(TimedText): """Return a new token with the time offset added.""" return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language) + def is_silence(self): + return False + + @dataclass class Sentence(TimedText): pass @@ -134,6 +141,46 @@ class Silence(): return None self.duration = self.end - self.start + def is_silence(self): + return True + + +@dataclass +class Segment(): + start: Optional[float] + end: Optional[float] + text: Optional[str] + speaker: Optional[str] + + @classmethod + def from_tokens( + cls, + tokens: List[Union[ASRToken, Silence]], + is_silence=False + ) -> "Segment": + if not tokens: + return None + + start_token = tokens[0] + end_token = tokens[-1] + if is_silence: + return cls( + start=start_token.start, + end=end_token.end, + text=None, + speaker = -2 + ) + else: + return cls( + start=start_token.start, + end=end_token.end, + text=''.join(token.text for token in tokens), + speaker = -1 + ) + def is_silence(self): + return self.speaker == -2 + + @dataclass class Line(TimedText): translation: str = '' @@ -158,6 +205,13 @@ class Line(TimedText): self.speaker = 1 return self + def build_from_segment(self, segment: Segment): + self.text = segment.text + self.start = segment.start + self.end = segment.end + self.speaker = segment.speaker + return self + def is_silent(self) -> bool: return self.speaker == -2 @@ -193,47 +247,6 @@ class FrontData(): _dict['error'] = self.error return _dict -@dataclass -class PunctuationSegment(): - """Represents a segment of text between punctuation marks.""" - start: Optional[float] - end: Optional[float] - token_index_start: int - token_index_end: int - punctuation_token_index: int - punctuation_token: ASRToken - - @classmethod - def from_token_range( - cls, - tokens: List[ASRToken], - token_index_start: int, - token_index_end: int, - punctuation_token_index: int - ) -> "PunctuationSegment": - """Create a PunctuationSegment from a range of tokens ending at a punctuation mark.""" - if not tokens or token_index_start < 0 or token_index_end >= len(tokens): - raise ValueError("Invalid token indices") - - start_token = tokens[token_index_start] - end_token = tokens[token_index_end] - punctuation_token = tokens[punctuation_token_index] - - # Build text from tokens in the segment - segment_tokens = tokens[token_index_start:token_index_end + 1] - text = ''.join(token.text for token in segment_tokens) - - return cls( - start=start_token.start, - end=end_token.end, - text=text, - token_index_start=token_index_start, - token_index_end=token_index_end, - punctuation_token_index=punctuation_token_index, - punctuation_token=punctuation_token - ) - - @dataclass class ChangeSpeaker: speaker: int @@ -260,4 +273,5 @@ class StateLight(): new_tokens: list = field(default_factory=list) new_translation: list = field(default_factory=list) new_diarization: list = field(default_factory=list) - new_tokens_buffer: list = field(default_factory=list) #only when local agreement \ No newline at end of file + new_tokens_buffer: list = field(default_factory=list) #only when local agreement + new_translation_buffer: str = '' \ No newline at end of file diff --git a/whisperlivekit/tokens_alignment.py b/whisperlivekit/tokens_alignment.py new file mode 100644 index 0000000..246fba3 --- /dev/null +++ b/whisperlivekit/tokens_alignment.py @@ -0,0 +1,179 @@ +from time import time +from typing import Optional + +from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment + + +class TokensAlignment: + + def __init__(self, state, args, sep): + self.state = state + self.diarization = args.diarization + self._tokens_index = 0 + self._diarization_index = 0 + self._translation_index = 0 + + self.all_tokens : list[ASRToken] = [] + self.all_diarization_segments: list[SpeakerSegment] = [] + self.all_translation_segments = [] + + self.new_tokens : list[ASRToken] = [] + self.new_diarization: list[SpeakerSegment] = [] + self.new_translation = [] + self.new_translation_buffer = TimedText() + self.new_tokens_buffer = [] + self.sep = sep if sep is not None else ' ' + self.beg_loop = None + + def update(self): + self.new_tokens, self.state.new_tokens = self.state.new_tokens, [] + self.new_diarization, self.state.new_diarization = self.state.new_diarization, [] + self.new_translation, self.state.new_translation = self.state.new_translation, [] + self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, [] + + self.all_tokens.extend(self.new_tokens) + self.all_diarization_segments.extend(self.new_diarization) + # self.all_translation_segments.extend(self.new_translation) #future + self.all_translation_segments = self.new_translation if self.new_translation != [] else self.all_translation_segments + self.new_translation_buffer = self.state.new_translation_buffer if self.new_translation else self.new_translation_buffer + self.new_translation_buffer = self.new_translation_buffer if type(self.new_translation_buffer) == str else self.new_translation_buffer.text + + def add_translation(self, line : Line): + + for ts in self.all_translation_segments: + if ts.is_within(line): + line.translation += ts.text + self.sep + elif line.translation: + break + + + def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None): + segments = [] + segment_start_idx = 0 + for i, token in enumerate(self.all_tokens): + if token.is_silence(): + previous_segment = Segment.from_tokens( + tokens=self.all_tokens[segment_start_idx: i], + ) + if previous_segment: + segments.append(previous_segment) + segment = Segment.from_tokens( + tokens=[token], + is_silence=True + ) + segments.append(segment) + segment_start_idx = i+1 + else: + if token.is_punctuation(): + segment = Segment.from_tokens( + tokens=self.all_tokens[segment_start_idx: i+1], + ) + segments.append(segment) + segment_start_idx = i+1 + + final_segment = Segment.from_tokens( + tokens=self.all_tokens[segment_start_idx:], + ) + if final_segment: + segments.append(final_segment) + return segments + + + def concatenate_diar_segments(self): + if not self.all_diarization_segments: + return [] + merged = [self.all_diarization_segments[0]] + for segment in self.all_diarization_segments[1:]: + if segment.speaker == merged[-1].speaker: + merged[-1].end = segment.end + else: + merged.append(segment) + return merged + + + @staticmethod + def intersection_duration(seg1, seg2): + start = max(seg1.start, seg2.start) + end = min(seg1.end, seg2.end) + + return max(0, end - start) + + def get_lines_diarization(self): + """ + use compute_punctuations_segments, concatenate_diar_segments, intersection_duration + """ + diarization_buffer = '' + punctuation_segments = self.compute_punctuations_segments() + diarization_segments = self.concatenate_diar_segments() + for punctuation_segment in punctuation_segments: + if not punctuation_segment.is_silence(): + if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end: + diarization_buffer += punctuation_segment.text + else: + max_overlap = 0.0 + max_overlap_speaker = 1 + for diarization_segment in diarization_segments: + intersec = self.intersection_duration(punctuation_segment, diarization_segment) + if intersec > max_overlap: + max_overlap = intersec + max_overlap_speaker = diarization_segment.speaker + 1 + punctuation_segment.speaker = max_overlap_speaker + + lines = [] + if punctuation_segments: + lines = [Line().build_from_segment(punctuation_segments[0])] + for segment in punctuation_segments[1:]: + if segment.speaker == lines[-1].speaker: + if lines[-1].text: + lines[-1].text += segment.text + lines[-1].end = segment.end + else: + lines.append(Line().build_from_segment(segment)) + + return lines, diarization_buffer + + + def get_lines( + self, + diarization=False, + translation=False, + current_silence=None + ): + """ + In the case without diarization + """ + if diarization: + lines, diarization_buffer = self.get_lines_diarization() + else: + diarization_buffer = '' + lines = [] + current_line_tokens = [] + for token in self.all_tokens: + if token.is_silence(): + if current_line_tokens: + lines.append(Line().build_from_tokens(current_line_tokens)) + current_line_tokens = [] + end_silence = token.end if token.has_ended else time() - self.beg_loop + if lines and lines[-1].is_silent(): + lines[-1].end = end_silence + else: + lines.append(SilentLine( + start = token.start, + end = end_silence + )) + else: + current_line_tokens.append(token) + if current_line_tokens: + lines.append(Line().build_from_tokens(current_line_tokens)) + if current_silence: + end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop + if lines and lines[-1].is_silent(): + lines[-1].end = end_silence + else: + lines.append(SilentLine( + start = current_silence.start, + end = end_silence + )) + if translation: + [self.add_translation(line) for line in lines if not type(line) == Silence] + return lines, diarization_buffer, self.new_translation_buffer diff --git a/whisperlivekit/trail_repetition.py b/whisperlivekit/trail_repetition.py deleted file mode 100644 index 18d9f5e..0000000 --- a/whisperlivekit/trail_repetition.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Sequence, Callable, Any, Optional, Dict - -def _detect_tail_repetition( - seq: Sequence[Any], - key: Callable[[Any], Any] = lambda x: x, # extract comparable value - min_block: int = 1, # set to 2 to ignore 1-token loops like "." - max_tail: int = 300, # search window from the end for speed - prefer: str = "longest", # "longest" coverage or "smallest" block -) -> Optional[Dict]: - vals = [key(x) for x in seq][-max_tail:] - n = len(vals) - best = None - - # try every possible block length - for b in range(min_block, n // 2 + 1): - block = vals[-b:] - # count how many times this block repeats contiguously at the very end - count, i = 0, n - while i - b >= 0 and vals[i - b:i] == block: - count += 1 - i -= b - - if count >= 2: - cand = { - "block_size": b, - "count": count, - "start_index": len(seq) - count * b, # in original seq - "end_index": len(seq), - } - if (best is None or - (prefer == "longest" and count * b > best["count"] * best["block_size"]) or - (prefer == "smallest" and b < best["block_size"])): - best = cand - return best - -def trim_tail_repetition( - seq: Sequence[Any], - key: Callable[[Any], Any] = lambda x: x, - min_block: int = 1, - max_tail: int = 300, - prefer: str = "longest", - keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common) -): - """ - Returns a new sequence with repeated tail trimmed. - keep=1 -> keep a single copy of the repeated block. - keep=0 -> remove all copies of the repeated block. - """ - rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer) - if not rep: - return seq, False # nothing to trim - - b, c = rep["block_size"], rep["count"] - if keep < 0: - keep = 0 - if keep >= c: - return seq, False # nothing to trim (already <= keep copies) - # new length = total - (copies_to_remove * block_size) - new_len = len(seq) - (c - keep) * b - return seq[:new_len], True \ No newline at end of file