stt/diar/nllw alignment: internal rework 5

This commit is contained in:
Quentin Fuxa
2025-11-20 23:52:00 +01:00
parent 8e7aea4fcf
commit 254faaf64c
8 changed files with 241 additions and 769 deletions

View File

@@ -1,240 +0,0 @@
from time import time
from typing import Optional
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence
from whisperlivekit.timed_objects import PunctuationSegment
ALIGNMENT_TIME_TOLERANCE = 0.2 # seconds
class TokensAlignment:
def __init__(self, state, args, sep):
self.state = state
self.diarization = args.diarization
self._tokens_index = 0
self._diarization_index = 0
self._translation_index = 0
self.all_tokens : list[ASRToken] = []
self.all_diarization_segments: list[SpeakerSegment] = []
self.all_translation_segments = []
self.new_tokens : list[ASRToken] = []
self.new_diarization: list[SpeakerSegment] = []
self.new_translation = []
self.new_tokens_buffer = []
self.sep = sep if sep is not None else ' '
self.beg_loop = None
def update(self):
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
self.new_translation, self.state.new_translation = self.state.new_translation, []
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
self.all_tokens.extend(self.new_tokens)
self.all_diarization_segments.extend(self.new_diarization)
self.all_translation_segments.extend(self.new_translation)
def get_lines(self, current_silence):
"""
In the case without diarization
"""
lines = []
current_line_tokens = []
for token in self.all_tokens:
if type(token) == Silence:
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(SilentLine(
start = token.start,
end = end_silence
))
else:
current_line_tokens.append(token)
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(SilentLine(
start = current_silence.start,
end = end_silence
))
return lines
def _get_asr_tokens(self) -> list[ASRToken]:
return [token for token in self.all_tokens if isinstance(token, ASRToken)]
def _tokens_to_text(self, tokens: list[ASRToken]) -> str:
return ''.join(token.text for token in tokens)
def _extract_detected_language(self, tokens: list[ASRToken]):
for token in tokens:
if getattr(token, 'detected_language', None):
return token.detected_language
return None
def _speaker_display_id(self, raw_speaker) -> int:
if isinstance(raw_speaker, int):
speaker_index = raw_speaker
else:
digits = ''.join(ch for ch in str(raw_speaker) if ch.isdigit())
speaker_index = int(digits) if digits else 0
return speaker_index + 1 if speaker_index >= 0 else 0
def _line_from_tokens(self, tokens: list[ASRToken], speaker: int) -> Line:
line = Line().build_from_tokens(tokens)
line.speaker = speaker
detected_language = self._extract_detected_language(tokens)
if detected_language:
line.detected_language = detected_language
return line
def _find_initial_diar_index(self, diar_segments: list[SpeakerSegment], start_time: float) -> int:
for idx, segment in enumerate(diar_segments):
if segment.end + ALIGNMENT_TIME_TOLERANCE >= start_time:
return idx
return len(diar_segments)
def _find_speaker_for_token(self, token: ASRToken, diar_segments: list[SpeakerSegment], diar_idx: int):
if not diar_segments:
return None, diar_idx
idx = min(diar_idx, len(diar_segments) - 1)
midpoint = (token.start + token.end) / 2 if token.end is not None else token.start
while idx < len(diar_segments) and diar_segments[idx].end + ALIGNMENT_TIME_TOLERANCE < midpoint:
idx += 1
candidate_indices = []
if idx < len(diar_segments):
candidate_indices.append(idx)
if idx > 0:
candidate_indices.append(idx - 1)
for candidate_idx in candidate_indices:
segment = diar_segments[candidate_idx]
seg_start = (segment.start or 0) - ALIGNMENT_TIME_TOLERANCE
seg_end = (segment.end or 0) + ALIGNMENT_TIME_TOLERANCE
if seg_start <= midpoint <= seg_end:
return segment.speaker, candidate_idx
return None, idx
def _build_lines_for_tokens(self, tokens: list[ASRToken], diar_segments: list[SpeakerSegment], diar_idx: int):
if not tokens:
return [], diar_idx
segment_lines: list[Line] = []
current_tokens: list[ASRToken] = []
current_speaker = None
pointer = diar_idx
for token in tokens:
speaker_raw, pointer = self._find_speaker_for_token(token, diar_segments, pointer)
if speaker_raw is None:
return [], diar_idx
speaker = self._speaker_display_id(speaker_raw)
if current_speaker is None or current_speaker != speaker:
if current_tokens:
segment_lines.append(self._line_from_tokens(current_tokens, current_speaker))
current_tokens = [token]
current_speaker = speaker
else:
current_tokens.append(token)
if current_tokens:
segment_lines.append(self._line_from_tokens(current_tokens, current_speaker))
return segment_lines, pointer
def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None):
"""Compute segments of text between punctuation marks.
Returns a list of PunctuationSegment objects, each representing
the text from the start (or previous punctuation) to the current punctuation mark.
"""
tokens = tokens if tokens is not None else self._get_asr_tokens()
if not tokens:
return []
punctuation_indices = [
i for i, token in enumerate[ASRToken](tokens)
if token.is_punctuation()
]
if not punctuation_indices:
return []
segments = []
for i, punct_idx in enumerate(punctuation_indices):
start_idx = punctuation_indices[i - 1] + 1 if i > 0 else 0
end_idx = punct_idx
if start_idx <= end_idx:
segment = PunctuationSegment.from_token_range(
tokens=tokens,
token_index_start=start_idx,
token_index_end=end_idx,
punctuation_token_index=punct_idx
)
segments.append(segment)
return segments
def concatenate_diar_segments(self):
if not self.all_diarization_segments:
return []
merged = [self.all_diarization_segments[0]]
for segment in self.all_diarization_segments[1:]:
if segment.speaker == merged[-1].speaker:
merged[-1].end = segment.end
else:
merged.append(segment)
return merged
def get_lines(self, diarization=False, translation=False):
"""
Align diarization speaker segments with punctuation-delimited transcription
segments (see docs/alignement_principles.md).
"""
tokens = self._get_asr_tokens()
if not tokens:
return [], ''
punctuation_segments = self.compute_punctuations_segments(tokens=tokens)
diar_segments = self.concatenate_diar_segments()
if not punctuation_segments or not diar_segments:
return [], self._tokens_to_text(tokens)
max_diar_end = diar_segments[-1].end
if max_diar_end is None:
return [], self._tokens_to_text(tokens)
lines: list[Line] = []
last_consumed_index = -1
diar_idx = self._find_initial_diar_index(diar_segments, tokens[0].start or 0)
for segment in punctuation_segments:
if segment.end is None or segment.end > max_diar_end:
break
slice_tokens = tokens[segment.token_index_start:segment.token_index_end + 1]
segment_lines, diar_idx = self._build_lines_for_tokens(slice_tokens, diar_segments, diar_idx)
if not segment_lines:
break
lines.extend(segment_lines)
last_consumed_index = segment.token_index_end
buffer_tokens = tokens[last_consumed_index + 1:] if last_consumed_index + 1 < len(tokens) else []
buffer_diarization = self._tokens_to_text(buffer_tokens)
return lines, buffer_diarization

View File

@@ -7,9 +7,8 @@ import traceback
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.results_formater import format_output
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.TokensAlignment import TokensAlignment
from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -392,8 +391,8 @@ class AudioProcessor:
self.translation.insert_tokens(tokens_to_process)
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
async with self.lock:
self.state.translation_validated_segments = translation_validated_segments
self.state.buffer_translation = buffer_translation
self.state_light.new_translation = translation_validated_segments
self.state_light.new_translation_buffer = buffer_translation
except Exception as e:
logger.warning(f"Exception in translation_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
@@ -412,11 +411,11 @@ class AudioProcessor:
self.tokens_alignment.update()
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
diarization=self.args.diarization,
translation=self.args.translation
translation=bool(self.translation),
current_silence=self.current_silence
)
state = await self.get_current_state()
buffer_translation_text = ''
buffer_transcription_text = ''
buffer_diarization_text = ''

View File

@@ -1,103 +0,0 @@
from whisperlivekit.timed_objects import ASRToken
from time import time
import re
MIN_SILENCE_DURATION = 4 #in seconds
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
def blank_to_silence(tokens):
    """Collapse runs of filler tokens ('[BLANK_AUDIO]', '[typing]') into silence tokens.

    The token texts are concatenated and the filler patterns are matched on the
    concatenated string; tokens fully inside the first pending match are merged
    into a single silence token (speaker == -2), kept only if its duration
    reaches MIN_SILENCE_DURATION. Returns the original list unchanged when no
    filler text is found.

    NOTE(review): a silence still open when the loop ends is never appended,
    and tokens after the last match are only appended while `matches` is
    non-empty -- presumably acceptable for streaming; confirm against callers.
    """
    full_string = ''.join([t.text for t in tokens])
    patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
    matches = []
    for pattern in patterns:
        for m in pattern.finditer(full_string):
            matches.append({
                'start': m.start(),
                'end': m.end()
            })
    if matches:
        # cleaned = pattern.sub(' ', full_string).strip()
        # print("Cleaned:", cleaned)
        cumulated_len = 0  # running character offset of the current token in full_string
        silence_token = None
        cleaned_tokens = []
        for token in tokens:
            if matches:
                start = cumulated_len
                end = cumulated_len + len(token.text)
                cumulated_len = end
                # Token lies entirely inside the first pending filler match.
                if start >= matches[0]['start'] and end <= matches[0]['end']:
                    if silence_token: #previous token was already silence
                        silence_token.start = min(silence_token.start, token.start)
                        silence_token.end = max(silence_token.end, token.end)
                    else: #new silence
                        silence_token = ASRToken(
                            start=token.start,
                            end=token.end,
                            speaker=-2,
                        )
                else:
                    if silence_token: #there was silence but no more
                        # Keep the merged silence only if it is long enough.
                        if silence_token.duration() >= MIN_SILENCE_DURATION:
                            cleaned_tokens.append(
                                silence_token
                            )
                        silence_token = None
                        matches.pop(0)
                    cleaned_tokens.append(token)
        # print(cleaned_tokens)
        return cleaned_tokens
    return tokens
def no_token_to_silence(tokens):
    """Normalize silences in a token stream.

    Consecutive silence tokens (speaker == -2) are merged into one, and a gap
    of at least MIN_SILENCE_DURATION between tokens is materialized as a new
    silence token. Returns a new list; speech tokens are kept as-is.
    """
    result = []
    for token in tokens:
        token_is_silence = token.speaker == -2
        if token_is_silence:
            # Merge with a preceding silence, otherwise keep the token itself.
            if result and result[-1].speaker == -2:
                result[-1].end = token.end
            else:
                result.append(token)
        previous_end = result[-1].end if result else 0.0
        if token.start - previous_end >= MIN_SILENCE_DURATION:
            # Large gap before this token: extend a trailing silence or insert one.
            if result and result[-1].speaker == -2:
                result[-1].end = token.start
            else:
                result.append(ASRToken(
                    start=previous_end,
                    end=token.start,
                    speaker=-2,
                ))
        if not token_is_silence:
            result.append(token)
    return result
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
    """Extend or append a trailing silence token when the stream currently ends silent.

    A trailing silence is detected either by the VAC flag or when no token has
    been produced for END_SILENCE_DURATION seconds of wall-clock time
    (relative to beg_loop). Mutates and returns `tokens`.

    NOTE(review): assumes `tokens` is non-empty -- handle_silences guards this;
    confirm no other caller passes an empty list.
    """
    # Wall-clock position in the stream; beg_loop may be None/0 at startup.
    current_time = time() - (beg_loop if beg_loop else 0.0)
    last_token = tokens[-1]
    if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
        if last_token.speaker == -2:
            # Already ends with silence: just extend it to "now".
            last_token.end = current_time
        else:
            tokens.append(
                ASRToken(
                    start=tokens[-1].end,
                    end=current_time,
                    speaker=-2,
                )
            )
    return tokens
def handle_silences(tokens, beg_loop, vac_detected_silence):
    """Run the full silence post-processing pipeline over a token list.

    Converts '[BLANK_AUDIO]'-style filler text to silence tokens (useful for
    the simulstreaming backend which tends to generate such text), marks large
    time gaps as silences, and finally closes the stream with a trailing
    silence when appropriate. Returns [] for an empty input.
    """
    if not tokens:
        return []
    without_fillers = blank_to_silence(tokens)
    with_gap_silences = no_token_to_silence(without_fillers)
    return ends_with_silence(with_gap_silences, beg_loop, vac_detected_silence)

View File

@@ -1,60 +0,0 @@
########### WHAT IS PRODUCED: ###########
SPEAKER 1 0:00:04 - 0:00:06
Transcription technology has improved so much in the past
SPEAKER 1 0:00:07 - 0:00:12
years. Have you noticed how accurate real-time speech detects is now?
SPEAKER 2 0:00:12 - 0:00:12
Absolutely
SPEAKER 1 0:00:13 - 0:00:13
.
SPEAKER 2 0:00:14 - 0:00:14
I
SPEAKER 1 0:00:14 - 0:00:17
use it all the time for taking notes during meetings.
SPEAKER 2 0:00:17 - 0:00:17
It
SPEAKER 1 0:00:17 - 0:00:22
's amazing how it can recognize different speakers, and even add punctuation.
SPEAKER 2 0:00:22 - 0:00:22
Yeah
SPEAKER 1 0:00:23 - 0:00:26
, but sometimes noise can still cause mistakes.
SPEAKER 3 0:00:26 - 0:00:27
Does
SPEAKER 1 0:00:27 - 0:00:28
this system handle that
SPEAKER 1 0:00:29 - 0:00:29
?
SPEAKER 3 0:00:29 - 0:00:29
It
SPEAKER 1 0:00:29 - 0:00:33
does a pretty good job filtering noise, especially with models that use voice activity
########### WHAT SHOULD BE PRODUCED: ###########
SPEAKER 1 0:00:04 - 0:00:12
Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now?
SPEAKER 2 0:00:12 - 0:00:22
Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation.
SPEAKER 3 0:00:22 - 0:00:28
Yeah, but sometimes noise can still cause mistakes. Does this system handle that well?
SPEAKER 1 0:00:29 - 0:00:29
It does a pretty good job filtering noise, especially with models that use voice activity

View File

@@ -1,257 +0,0 @@
import logging
import re
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment
from typing import List
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
CHECK_AROUND = 4
DEBUG = False
def next_punctuation_change(i, tokens):
    """Return the index of the first punctuation token after position `i`,
    looking at most CHECK_AROUND tokens ahead; None when none is found."""
    upper = min(len(tokens), i + CHECK_AROUND + 1)
    for candidate in range(i + 1, upper):
        if tokens[candidate].is_punctuation():
            return candidate
    return None
def next_speaker_change(i, tokens, speaker):
    """Scan backwards from position `i` (at most CHECK_AROUND tokens, stopping
    at punctuation) for a token with a different speaker.

    Returns (index, other_speaker) when found, else (None, speaker).
    """
    lower = max(0, i - CHECK_AROUND) - 1
    for candidate in range(i - 1, lower, -1):
        previous = tokens[candidate]
        if previous.is_punctuation():
            break
        if previous.speaker != speaker:
            return candidate, previous.speaker
    return None, speaker
def new_line(
    token,
):
    """Start a fresh Line seeded from a single token (speaker, text, timing).

    When DEBUG is on, the token's time span is appended to the text.
    """
    debug_suffix = f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""
    return Line(
        speaker=token.corrected_speaker,
        text=token.text + debug_suffix,
        start=token.start,
        end=token.end,
        detected_language=token.detected_language,
    )
def append_token_to_last_line(lines, sep, token):
    """Append a token to the last line in `lines`, creating one when empty.

    Extends the line's end time and fills in detected_language when the line
    does not have one yet.
    """
    if not lines:
        lines.append(new_line(token))
        return
    last = lines[-1]
    if token.text:
        debug_suffix = f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""
        last.text += sep + token.text + debug_suffix
    last.end = token.end
    if not last.detected_language and token.detected_language:
        last.detected_language = token.detected_language
def extract_number(s) -> int:
    """Extract number from speaker string (for diart compatibility).

    Integers pass through unchanged; otherwise the first digit run in the
    string form is returned, or 0 when there is none.
    """
    if isinstance(s, int):
        return s
    match = re.search(r'\d+', str(s))
    if match:
        return int(match.group())
    return 0
def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]:
    """Concatenate consecutive segments from the same speaker.

    Speaker labels are converted to 1-based ints via extract_number. Returns
    dicts with keys 'speaker', 'begin', 'end'.
    """
    if not segments:
        return []
    head = segments[0]
    merged = [{
        "speaker": extract_number(head.speaker) + 1,
        "begin": head.start,
        "end": head.end,
    }]
    for segment in segments[1:]:
        speaker_id = extract_number(segment.speaker) + 1
        if merged[-1]['speaker'] == speaker_id:
            # Same speaker keeps talking: just extend the previous span.
            merged[-1]['end'] = segment.end
        else:
            merged.append({"speaker": speaker_id, "begin": segment.start, "end": segment.end})
    return merged
def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list:
    """Assign speakers to tokens with punctuation-aware boundary adjustment.

    First snaps each concatenated speaker segment's end to the nearest
    sentence boundary (closest punctuation token), then forces tokens to be
    non-overlapping in time, and finally stamps each token's speaker from the
    adjusted segments. Mutates and returns `tokens`.
    """
    punctuation_marks = {'.', '!', '?'}
    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
    segments_concatenated = concatenate_speakers(segments)
    for ind, segment in enumerate(segments_concatenated):
        for i, punctuation_token in enumerate(punctuation_tokens):
            # First punctuation strictly after this segment's end: decide whether
            # to extend the segment to it or pull it back to the previous one.
            if punctuation_token.start > segment['end']:
                after_length = punctuation_token.start - segment['end']
                before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
                if before_length > after_length:
                    segment['end'] = punctuation_token.start
                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                else:
                    segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
                    # NOTE(review): updating the *previous* segment's begin here looks
                    # suspicious (the preceding segment was already processed) -- confirm.
                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                break
    # Ensure non-overlapping tokens
    last_end = 0.0
    for token in tokens:
        start = max(last_end + 0.01, token.start)
        token.start = start
        token.end = max(start, token.end)
        last_end = token.end
    # Assign speakers based on adjusted segments
    # NOTE(review): `ind_last_speaker = i + 1` stores an index relative to the
    # slice `tokens[ind_last_speaker:]`, not to `tokens` -- looks like an
    # off-by bug; confirm intended behavior before reuse.
    ind_last_speaker = 0
    for segment in segments_concatenated:
        for i, token in enumerate(tokens[ind_last_speaker:]):
            if token.end <= segment['end']:
                token.speaker = segment['speaker']
                ind_last_speaker = i + 1
            elif token.start > segment['end']:
                break
    return tokens
def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list:
    """
    Assign speakers to tokens based on timing overlap with speaker segments.
    Args:
        tokens: List of tokens with timing information
        segments: List of speaker segments
        use_punctuation_split: Whether to use punctuation for boundary refinement
    Returns:
        List of tokens with speaker assignments
    """
    if not segments or not tokens:
        logger.debug("No segments or tokens available for speaker assignment")
        return tokens
    logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
    if use_punctuation_split:
        # Punctuation-aware assignment delegates the whole job.
        return add_speaker_to_tokens_with_punctuation(segments, tokens)
    # Simple overlap-based assignment: first overlapping segment wins.
    for token in tokens:
        token.speaker = -1  # Default to no speaker
        for segment in segments:
            overlaps = segment.end > token.start and segment.start < token.end
            if overlaps:
                # Convert to 1-based indexing
                token.speaker = extract_number(segment.speaker) + 1
                break
    return tokens
def format_output(state, silence, args, sep):
    """Build display Lines from the state's tokens and attach translations.

    Walks tokens from the last validated index, decides a corrected speaker
    per token, groups tokens into Lines on speaker changes, then distributes
    translation segments onto the lines by time overlap.
    Returns (lines, undiarized_text).
    """
    diarization = args.diarization
    disable_punctuation_split = args.disable_punctuation_split
    tokens = state.tokens
    translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
    last_validated_token = state.last_validated_token
    last_speaker = abs(state.last_speaker)
    undiarized_text = []
    tokens = handle_silences(tokens, state.beg_loop, silence)
    # Assign speakers to tokens based on segments stored in state
    # NOTE(review): branch hard-disabled with `if False` -- kept as reference code.
    if False and diarization and state.diarization_segments:
        use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False
        tokens = assign_speakers_to_tokens(tokens, state.diarization_segments, use_punctuation_split=use_punctuation_split)
    for i in range(last_validated_token, len(tokens)):
        token = tokens[i]
        speaker = int(token.speaker)
        token.corrected_speaker = speaker
        # NOTE(review): `if True or ...` forces the simple path; the
        # punctuation-aware correction in the else branch is unreachable.
        if True or not diarization:
            if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
                token.corrected_speaker = 1
            token.validated_speaker = True
        else:
            if token.speaker == -1:
                undiarized_text.append(token.text)
            elif token.is_punctuation():
                state.last_punctuation_index = i
                token.corrected_speaker = last_speaker
                token.validated_speaker = True
            elif state.last_punctuation_index == i-1:
                if token.speaker != last_speaker:
                    token.corrected_speaker = token.speaker
                    token.validated_speaker = True
                    # perfect, diarization perfectly aligned
            else:
                speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
                if speaker_change_pos:
                    # Corrects delay:
                    # That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
                    # should become:
                    # That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
                    token.corrected_speaker = new_speaker
                    token.validated_speaker = True
                elif speaker != last_speaker:
                    if not (speaker == -2 or last_speaker == -2):
                        if next_punctuation_change(i, tokens):
                            # Corrects advance:
                            # Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
                            # should become:
                            # Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
                            token.corrected_speaker = last_speaker
                            token.validated_speaker = True
                        else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
                            if not disable_punctuation_split:
                                token.corrected_speaker = last_speaker
                            token.validated_speaker = False
        if token.validated_speaker:
            state.last_validated_token = i
        state.last_speaker = token.corrected_speaker
    # Group tokens into lines: a new line starts whenever the corrected
    # speaker changes; tokens without a speaker (-1) are dropped here.
    last_speaker = 1
    lines = []
    for token in tokens:
        if token.corrected_speaker != -1:
            if int(token.corrected_speaker) != int(last_speaker):
                lines.append(new_line(token))
            else:
                append_token_to_last_line(lines, sep, token)
            last_speaker = token.corrected_speaker
    if lines:
        # First pass: put each translation segment on the line it falls in,
        # cutting a segment when it spans a line boundary.
        unassigned_translated_segments = []
        for ts in translation_validated_segments:
            assigned = False
            for line in lines:
                if ts and ts.overlaps_with(line):
                    if ts.is_within(line):
                        line.translation += ts.text + ' '
                        assigned = True
                        break
                    else:
                        ts0, ts1 = ts.approximate_cut_at(line.end)
                        if ts0 and line.overlaps_with(ts0):
                            line.translation += ts0.text + ' '
                            if ts1:
                                unassigned_translated_segments.append(ts1)
                            assigned = True
                            break
            if not assigned:
                unassigned_translated_segments.append(ts)
        # Second pass: best-effort placement of leftovers by overlap only.
        if unassigned_translated_segments:
            for line in lines:
                remaining_segments = []
                for ts in unassigned_translated_segments:
                    if ts and ts.overlaps_with(line):
                        line.translation += ts.text + ' '
                    else:
                        remaining_segments.append(ts)
                unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
    # Extend the last line to cover the not-yet-validated transcription buffer.
    if state.buffer_transcription and lines:
        lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
    return lines, undiarized_text

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Optional, Any, List
from datetime import timedelta
from typing import Union
PUNCTUATION_MARKS = {'.', '!', '?', '', '', ''}
@@ -39,7 +40,9 @@ class TimedText(Timed):
def __bool__(self):
return bool(self.text)
def __str__(self):
return str(self.text)
@dataclass()
class ASRToken(TimedText):
@@ -53,6 +56,10 @@ class ASRToken(TimedText):
"""Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
def is_silence(self):
return False
@dataclass
class Sentence(TimedText):
pass
@@ -134,6 +141,46 @@ class Silence():
return None
self.duration = self.end - self.start
def is_silence(self):
return True
@dataclass
class Segment():
    """A time span of either spoken text (speaker == -1) or silence (speaker == -2)."""
    start: Optional[float]
    end: Optional[float]
    text: Optional[str]
    speaker: Optional[str]

    @classmethod
    def from_tokens(
        cls,
        tokens: List[Union[ASRToken, Silence]],
        is_silence=False
    ) -> "Segment":
        """Build a Segment spanning `tokens`; returns None for an empty list.

        Silence segments carry no text and speaker -2; speech segments join
        the token texts and get the placeholder speaker -1.
        """
        if not tokens:
            return None
        first, last = tokens[0], tokens[-1]
        if is_silence:
            return cls(start=first.start, end=last.end, text=None, speaker=-2)
        joined_text = ''.join(token.text for token in tokens)
        return cls(start=first.start, end=last.end, text=joined_text, speaker=-1)

    def is_silence(self):
        """True when this segment marks silence rather than speech."""
        return self.speaker == -2
@dataclass
class Line(TimedText):
translation: str = ''
@@ -158,6 +205,13 @@ class Line(TimedText):
self.speaker = 1
return self
def build_from_segment(self, segment: Segment):
self.text = segment.text
self.start = segment.start
self.end = segment.end
self.speaker = segment.speaker
return self
def is_silent(self) -> bool:
return self.speaker == -2
@@ -193,47 +247,6 @@ class FrontData():
_dict['error'] = self.error
return _dict
@dataclass
class PunctuationSegment():
    """Represents a segment of text between punctuation marks.

    Fixed: the dataclass previously had no `text` field although
    `from_token_range` passed `text=...` to the constructor, which raised a
    TypeError at runtime; the field is now declared.
    """
    start: Optional[float]
    end: Optional[float]
    text: Optional[str]
    token_index_start: int
    token_index_end: int
    punctuation_token_index: int
    punctuation_token: ASRToken

    @classmethod
    def from_token_range(
        cls,
        tokens: List[ASRToken],
        token_index_start: int,
        token_index_end: int,
        punctuation_token_index: int
    ) -> "PunctuationSegment":
        """Create a PunctuationSegment from a range of tokens ending at a punctuation mark.

        Raises ValueError when the indices do not address `tokens`.
        """
        if not tokens or token_index_start < 0 or token_index_end >= len(tokens):
            raise ValueError("Invalid token indices")
        start_token = tokens[token_index_start]
        end_token = tokens[token_index_end]
        punctuation_token = tokens[punctuation_token_index]
        # Build text from tokens in the segment
        segment_tokens = tokens[token_index_start:token_index_end + 1]
        text = ''.join(token.text for token in segment_tokens)
        return cls(
            start=start_token.start,
            end=end_token.end,
            text=text,
            token_index_start=token_index_start,
            token_index_end=token_index_end,
            punctuation_token_index=punctuation_token_index,
            punctuation_token=punctuation_token
        )
@dataclass
class ChangeSpeaker:
speaker: int
@@ -260,4 +273,5 @@ class StateLight():
new_tokens: list = field(default_factory=list)
new_translation: list = field(default_factory=list)
new_diarization: list = field(default_factory=list)
new_tokens_buffer: list = field(default_factory=list) #only when local agreement
new_tokens_buffer: list = field(default_factory=list) #only when local agreement
new_translation_buffer: str = ''

View File

@@ -0,0 +1,179 @@
from time import time
from typing import Optional
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
class TokensAlignment:
    """Aggregates ASR tokens, diarization segments and translation segments
    drained from the shared state, and renders them as display Lines."""

    def __init__(self, state, args, sep):
        self.state = state
        self.diarization = args.diarization
        self._tokens_index = 0
        self._diarization_index = 0
        self._translation_index = 0
        self.all_tokens: list[ASRToken] = []
        self.all_diarization_segments: list[SpeakerSegment] = []
        self.all_translation_segments = []
        self.new_tokens: list[ASRToken] = []
        self.new_diarization: list[SpeakerSegment] = []
        self.new_translation = []
        self.new_translation_buffer = TimedText()
        self.new_tokens_buffer = []
        self.sep = sep if sep is not None else ' '
        # NOTE(review): beg_loop stays None until set externally -- get_lines
        # subtracts it from time() for an unfinished silence; confirm it is
        # assigned before that path runs.
        self.beg_loop = None

    def update(self):
        """Drain the shared state's freshly produced items into the local
        accumulators (tokens/diarization are appended, translation replaced)."""
        self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
        self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
        self.new_translation, self.state.new_translation = self.state.new_translation, []
        self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
        self.all_tokens.extend(self.new_tokens)
        self.all_diarization_segments.extend(self.new_diarization)
        # self.all_translation_segments.extend(self.new_translation) #future
        # Translation arrives as a full re-emission, so replace instead of extend.
        if self.new_translation:
            self.all_translation_segments = self.new_translation
            self.new_translation_buffer = self.state.new_translation_buffer
        # Normalize the buffer to plain text (it may arrive as a TimedText).
        # Fixed idiom: was `type(...) == str`.
        if not isinstance(self.new_translation_buffer, str):
            self.new_translation_buffer = self.new_translation_buffer.text

    def add_translation(self, line: Line):
        """Append the text of translation segments lying within `line`.

        Segments are assumed time-ordered, so iteration stops once we leave
        the line after having appended something.
        """
        for ts in self.all_translation_segments:
            if ts.is_within(line):
                line.translation += ts.text + self.sep
            elif line.translation:
                break

    def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None):
        """Split tokens into Segments delimited by punctuation and silences.

        Fixed: the `tokens` parameter was previously accepted but ignored (the
        body always read self.all_tokens); it now defaults to self.all_tokens.
        """
        if tokens is None:
            tokens = self.all_tokens
        segments = []
        segment_start_idx = 0
        for i, token in enumerate(tokens):
            if token.is_silence():
                # Close the pending speech segment, then emit the silence.
                previous_segment = Segment.from_tokens(
                    tokens=tokens[segment_start_idx: i],
                )
                if previous_segment:
                    segments.append(previous_segment)
                segments.append(Segment.from_tokens(
                    tokens=[token],
                    is_silence=True
                ))
                segment_start_idx = i + 1
            elif token.is_punctuation():
                segments.append(Segment.from_tokens(
                    tokens=tokens[segment_start_idx: i + 1],
                ))
                segment_start_idx = i + 1
        # Trailing tokens without a closing punctuation still form a segment.
        final_segment = Segment.from_tokens(
            tokens=tokens[segment_start_idx:],
        )
        if final_segment:
            segments.append(final_segment)
        return segments

    def concatenate_diar_segments(self):
        """Merge consecutive diarization segments sharing the same speaker."""
        if not self.all_diarization_segments:
            return []
        merged = [self.all_diarization_segments[0]]
        for segment in self.all_diarization_segments[1:]:
            if segment.speaker == merged[-1].speaker:
                merged[-1].end = segment.end
            else:
                merged.append(segment)
        return merged

    @staticmethod
    def intersection_duration(seg1, seg2):
        """Duration (>= 0) of the time overlap between two timed segments."""
        start = max(seg1.start, seg2.start)
        end = min(seg1.end, seg2.end)
        return max(0, end - start)

    def get_lines_diarization(self):
        """Assign speakers to punctuation-delimited segments by maximum
        time-overlap with diarization turns, then merge same-speaker segments
        into Lines.

        Returns (lines, diarization_buffer) where the buffer holds text not
        yet covered by diarization.
        """
        diarization_buffer = ''
        punctuation_segments = self.compute_punctuations_segments()
        diarization_segments = self.concatenate_diar_segments()
        for punctuation_segment in punctuation_segments:
            if punctuation_segment.is_silence():
                continue
            if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
                # Not diarized yet: keep the text aside.
                diarization_buffer += punctuation_segment.text
            else:
                best_overlap = 0.0
                best_speaker = 1
                for diarization_segment in diarization_segments:
                    overlap = self.intersection_duration(punctuation_segment, diarization_segment)
                    if overlap > best_overlap:
                        best_overlap = overlap
                        # NOTE(review): assumes diarization speakers are 0-based ints -- confirm.
                        best_speaker = diarization_segment.speaker + 1
                punctuation_segment.speaker = best_speaker
        lines = []
        if punctuation_segments:
            lines = [Line().build_from_segment(punctuation_segments[0])]
            for segment in punctuation_segments[1:]:
                if segment.speaker == lines[-1].speaker:
                    if lines[-1].text:
                        lines[-1].text += segment.text
                    lines[-1].end = segment.end
                else:
                    lines.append(Line().build_from_segment(segment))
        return lines, diarization_buffer

    def get_lines(
        self,
        diarization=False,
        translation=False,
        current_silence=None
    ):
        """Render the accumulated tokens as display lines.

        Returns (lines, diarization_buffer, translation_buffer_text).
        Without diarization, tokens are grouped into Lines split by silences.
        """
        if diarization:
            lines, diarization_buffer = self.get_lines_diarization()
        else:
            diarization_buffer = ''
            lines = []
            current_line_tokens = []
            for token in self.all_tokens:
                if token.is_silence():
                    if current_line_tokens:
                        lines.append(Line().build_from_tokens(current_line_tokens))
                        current_line_tokens = []
                    end_silence = token.end if token.has_ended else time() - self.beg_loop
                    if lines and lines[-1].is_silent():
                        lines[-1].end = end_silence
                    else:
                        lines.append(SilentLine(
                            start=token.start,
                            end=end_silence
                        ))
                else:
                    current_line_tokens.append(token)
            if current_line_tokens:
                lines.append(Line().build_from_tokens(current_line_tokens))
        if current_silence:
            end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
            if lines and lines[-1].is_silent():
                lines[-1].end = end_silence
            else:
                lines.append(SilentLine(
                    start=current_silence.start,
                    end=end_silence
                ))
        if translation:
            for line in lines:
                # Fixed: the previous filter compared against Silence, which
                # never appears in `lines` (they hold Line/SilentLine), so
                # silent lines were not skipped.
                if isinstance(line, SilentLine) or line.is_silent():
                    continue
                self.add_translation(line)
        return lines, diarization_buffer, self.new_translation_buffer

View File

@@ -1,60 +0,0 @@
from typing import Sequence, Callable, Any, Optional, Dict
def _detect_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
max_tail: int = 300, # search window from the end for speed
prefer: str = "longest", # "longest" coverage or "smallest" block
) -> Optional[Dict]:
vals = [key(x) for x in seq][-max_tail:]
n = len(vals)
best = None
# try every possible block length
for b in range(min_block, n // 2 + 1):
block = vals[-b:]
# count how many times this block repeats contiguously at the very end
count, i = 0, n
while i - b >= 0 and vals[i - b:i] == block:
count += 1
i -= b
if count >= 2:
cand = {
"block_size": b,
"count": count,
"start_index": len(seq) - count * b, # in original seq
"end_index": len(seq),
}
if (best is None or
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
(prefer == "smallest" and b < best["block_size"])):
best = cand
return best
def trim_tail_repetition(
seq: Sequence[Any],
key: Callable[[Any], Any] = lambda x: x,
min_block: int = 1,
max_tail: int = 300,
prefer: str = "longest",
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
):
"""
Returns a new sequence with repeated tail trimmed.
keep=1 -> keep a single copy of the repeated block.
keep=0 -> remove all copies of the repeated block.
"""
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
if not rep:
return seq, False # nothing to trim
b, c = rep["block_size"], rep["count"]
if keep < 0:
keep = 0
if keep >= c:
return seq, False # nothing to trim (already <= keep copies)
# new length = total - (copies_to_remove * block_size)
new_len = len(seq) - (c - keep) * b
return seq[:new_len], True