mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
stt/diar/nllw alignment: internal rework 5
This commit is contained in:
@@ -1,240 +0,0 @@
|
||||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence
|
||||
from whisperlivekit.timed_objects import PunctuationSegment
|
||||
|
||||
ALIGNMENT_TIME_TOLERANCE = 0.2 # seconds
|
||||
|
||||
|
||||
class TokensAlignment:
|
||||
|
||||
def __init__(self, state, args, sep):
|
||||
self.state = state
|
||||
self.diarization = args.diarization
|
||||
self._tokens_index = 0
|
||||
self._diarization_index = 0
|
||||
self._translation_index = 0
|
||||
|
||||
self.all_tokens : list[ASRToken] = []
|
||||
self.all_diarization_segments: list[SpeakerSegment] = []
|
||||
self.all_translation_segments = []
|
||||
|
||||
self.new_tokens : list[ASRToken] = []
|
||||
self.new_diarization: list[SpeakerSegment] = []
|
||||
self.new_translation = []
|
||||
self.new_tokens_buffer = []
|
||||
self.sep = sep if sep is not None else ' '
|
||||
self.beg_loop = None
|
||||
|
||||
def update(self):
|
||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
|
||||
self.new_translation, self.state.new_translation = self.state.new_translation, []
|
||||
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
|
||||
|
||||
self.all_tokens.extend(self.new_tokens)
|
||||
self.all_diarization_segments.extend(self.new_diarization)
|
||||
self.all_translation_segments.extend(self.new_translation)
|
||||
|
||||
def get_lines(self, current_silence):
|
||||
"""
|
||||
In the case without diarization
|
||||
"""
|
||||
lines = []
|
||||
current_line_tokens = []
|
||||
for token in self.all_tokens:
|
||||
if type(token) == Silence:
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
current_line_tokens = []
|
||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = token.start,
|
||||
end = end_silence
|
||||
))
|
||||
else:
|
||||
current_line_tokens.append(token)
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = current_silence.start,
|
||||
end = end_silence
|
||||
))
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _get_asr_tokens(self) -> list[ASRToken]:
|
||||
return [token for token in self.all_tokens if isinstance(token, ASRToken)]
|
||||
|
||||
def _tokens_to_text(self, tokens: list[ASRToken]) -> str:
|
||||
return ''.join(token.text for token in tokens)
|
||||
|
||||
def _extract_detected_language(self, tokens: list[ASRToken]):
|
||||
for token in tokens:
|
||||
if getattr(token, 'detected_language', None):
|
||||
return token.detected_language
|
||||
return None
|
||||
|
||||
def _speaker_display_id(self, raw_speaker) -> int:
|
||||
if isinstance(raw_speaker, int):
|
||||
speaker_index = raw_speaker
|
||||
else:
|
||||
digits = ''.join(ch for ch in str(raw_speaker) if ch.isdigit())
|
||||
speaker_index = int(digits) if digits else 0
|
||||
return speaker_index + 1 if speaker_index >= 0 else 0
|
||||
|
||||
def _line_from_tokens(self, tokens: list[ASRToken], speaker: int) -> Line:
|
||||
line = Line().build_from_tokens(tokens)
|
||||
line.speaker = speaker
|
||||
detected_language = self._extract_detected_language(tokens)
|
||||
if detected_language:
|
||||
line.detected_language = detected_language
|
||||
return line
|
||||
|
||||
def _find_initial_diar_index(self, diar_segments: list[SpeakerSegment], start_time: float) -> int:
|
||||
for idx, segment in enumerate(diar_segments):
|
||||
if segment.end + ALIGNMENT_TIME_TOLERANCE >= start_time:
|
||||
return idx
|
||||
return len(diar_segments)
|
||||
|
||||
def _find_speaker_for_token(self, token: ASRToken, diar_segments: list[SpeakerSegment], diar_idx: int):
|
||||
if not diar_segments:
|
||||
return None, diar_idx
|
||||
idx = min(diar_idx, len(diar_segments) - 1)
|
||||
midpoint = (token.start + token.end) / 2 if token.end is not None else token.start
|
||||
|
||||
while idx < len(diar_segments) and diar_segments[idx].end + ALIGNMENT_TIME_TOLERANCE < midpoint:
|
||||
idx += 1
|
||||
|
||||
candidate_indices = []
|
||||
if idx < len(diar_segments):
|
||||
candidate_indices.append(idx)
|
||||
if idx > 0:
|
||||
candidate_indices.append(idx - 1)
|
||||
|
||||
for candidate_idx in candidate_indices:
|
||||
segment = diar_segments[candidate_idx]
|
||||
seg_start = (segment.start or 0) - ALIGNMENT_TIME_TOLERANCE
|
||||
seg_end = (segment.end or 0) + ALIGNMENT_TIME_TOLERANCE
|
||||
if seg_start <= midpoint <= seg_end:
|
||||
return segment.speaker, candidate_idx
|
||||
|
||||
return None, idx
|
||||
|
||||
def _build_lines_for_tokens(self, tokens: list[ASRToken], diar_segments: list[SpeakerSegment], diar_idx: int):
|
||||
if not tokens:
|
||||
return [], diar_idx
|
||||
|
||||
segment_lines: list[Line] = []
|
||||
current_tokens: list[ASRToken] = []
|
||||
current_speaker = None
|
||||
pointer = diar_idx
|
||||
|
||||
for token in tokens:
|
||||
speaker_raw, pointer = self._find_speaker_for_token(token, diar_segments, pointer)
|
||||
if speaker_raw is None:
|
||||
return [], diar_idx
|
||||
speaker = self._speaker_display_id(speaker_raw)
|
||||
if current_speaker is None or current_speaker != speaker:
|
||||
if current_tokens:
|
||||
segment_lines.append(self._line_from_tokens(current_tokens, current_speaker))
|
||||
current_tokens = [token]
|
||||
current_speaker = speaker
|
||||
else:
|
||||
current_tokens.append(token)
|
||||
|
||||
if current_tokens:
|
||||
segment_lines.append(self._line_from_tokens(current_tokens, current_speaker))
|
||||
|
||||
return segment_lines, pointer
|
||||
|
||||
def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None):
|
||||
"""Compute segments of text between punctuation marks.
|
||||
|
||||
Returns a list of PunctuationSegment objects, each representing
|
||||
the text from the start (or previous punctuation) to the current punctuation mark.
|
||||
"""
|
||||
|
||||
tokens = tokens if tokens is not None else self._get_asr_tokens()
|
||||
if not tokens:
|
||||
return []
|
||||
punctuation_indices = [
|
||||
i for i, token in enumerate[ASRToken](tokens)
|
||||
if token.is_punctuation()
|
||||
]
|
||||
if not punctuation_indices:
|
||||
return []
|
||||
|
||||
segments = []
|
||||
for i, punct_idx in enumerate(punctuation_indices):
|
||||
start_idx = punctuation_indices[i - 1] + 1 if i > 0 else 0
|
||||
end_idx = punct_idx
|
||||
if start_idx <= end_idx:
|
||||
segment = PunctuationSegment.from_token_range(
|
||||
tokens=tokens,
|
||||
token_index_start=start_idx,
|
||||
token_index_end=end_idx,
|
||||
punctuation_token_index=punct_idx
|
||||
)
|
||||
segments.append(segment)
|
||||
return segments
|
||||
|
||||
|
||||
def concatenate_diar_segments(self):
|
||||
if not self.all_diarization_segments:
|
||||
return []
|
||||
merged = [self.all_diarization_segments[0]]
|
||||
for segment in self.all_diarization_segments[1:]:
|
||||
if segment.speaker == merged[-1].speaker:
|
||||
merged[-1].end = segment.end
|
||||
else:
|
||||
merged.append(segment)
|
||||
return merged
|
||||
|
||||
def get_lines(self, diarization=False, translation=False):
|
||||
"""
|
||||
Align diarization speaker segments with punctuation-delimited transcription
|
||||
segments (see docs/alignement_principles.md).
|
||||
"""
|
||||
tokens = self._get_asr_tokens()
|
||||
if not tokens:
|
||||
return [], ''
|
||||
|
||||
punctuation_segments = self.compute_punctuations_segments(tokens=tokens)
|
||||
diar_segments = self.concatenate_diar_segments()
|
||||
|
||||
if not punctuation_segments or not diar_segments:
|
||||
return [], self._tokens_to_text(tokens)
|
||||
|
||||
max_diar_end = diar_segments[-1].end
|
||||
if max_diar_end is None:
|
||||
return [], self._tokens_to_text(tokens)
|
||||
|
||||
lines: list[Line] = []
|
||||
last_consumed_index = -1
|
||||
diar_idx = self._find_initial_diar_index(diar_segments, tokens[0].start or 0)
|
||||
|
||||
for segment in punctuation_segments:
|
||||
if segment.end is None or segment.end > max_diar_end:
|
||||
break
|
||||
slice_tokens = tokens[segment.token_index_start:segment.token_index_end + 1]
|
||||
segment_lines, diar_idx = self._build_lines_for_tokens(slice_tokens, diar_segments, diar_idx)
|
||||
if not segment_lines:
|
||||
break
|
||||
lines.extend(segment_lines)
|
||||
last_consumed_index = segment.token_index_end
|
||||
|
||||
buffer_tokens = tokens[last_consumed_index + 1:] if last_consumed_index + 1 < len(tokens) else []
|
||||
buffer_diarization = self._tokens_to_text(buffer_tokens)
|
||||
|
||||
return lines, buffer_diarization
|
||||
@@ -7,9 +7,8 @@ import traceback
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.TokensAlignment import TokensAlignment
|
||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@@ -392,8 +391,8 @@ class AudioProcessor:
|
||||
self.translation.insert_tokens(tokens_to_process)
|
||||
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
|
||||
async with self.lock:
|
||||
self.state.translation_validated_segments = translation_validated_segments
|
||||
self.state.buffer_translation = buffer_translation
|
||||
self.state_light.new_translation = translation_validated_segments
|
||||
self.state_light.new_translation_buffer = buffer_translation
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
@@ -412,11 +411,11 @@ class AudioProcessor:
|
||||
self.tokens_alignment.update()
|
||||
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
|
||||
diarization=self.args.diarization,
|
||||
translation=self.args.translation
|
||||
translation=bool(self.translation),
|
||||
current_silence=self.current_silence
|
||||
)
|
||||
state = await self.get_current_state()
|
||||
|
||||
buffer_translation_text = ''
|
||||
buffer_transcription_text = ''
|
||||
buffer_diarization_text = ''
|
||||
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from time import time
|
||||
import re
|
||||
|
||||
MIN_SILENCE_DURATION = 4 #in seconds
|
||||
END_SILENCE_DURATION = 8 #in seconds. you should keep it important to not have false positive when the model lag is important
|
||||
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
|
||||
|
||||
def blank_to_silence(tokens):
|
||||
full_string = ''.join([t.text for t in tokens])
|
||||
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
|
||||
matches = []
|
||||
for pattern in patterns:
|
||||
for m in pattern.finditer(full_string):
|
||||
matches.append({
|
||||
'start': m.start(),
|
||||
'end': m.end()
|
||||
})
|
||||
if matches:
|
||||
# cleaned = pattern.sub(' ', full_string).strip()
|
||||
# print("Cleaned:", cleaned)
|
||||
cumulated_len = 0
|
||||
silence_token = None
|
||||
cleaned_tokens = []
|
||||
for token in tokens:
|
||||
if matches:
|
||||
start = cumulated_len
|
||||
end = cumulated_len + len(token.text)
|
||||
cumulated_len = end
|
||||
if start >= matches[0]['start'] and end <= matches[0]['end']:
|
||||
if silence_token: #previous token was already silence
|
||||
silence_token.start = min(silence_token.start, token.start)
|
||||
silence_token.end = max(silence_token.end, token.end)
|
||||
else: #new silence
|
||||
silence_token = ASRToken(
|
||||
start=token.start,
|
||||
end=token.end,
|
||||
speaker=-2,
|
||||
)
|
||||
else:
|
||||
if silence_token: #there was silence but no more
|
||||
if silence_token.duration() >= MIN_SILENCE_DURATION:
|
||||
cleaned_tokens.append(
|
||||
silence_token
|
||||
)
|
||||
silence_token = None
|
||||
matches.pop(0)
|
||||
cleaned_tokens.append(token)
|
||||
# print(cleaned_tokens)
|
||||
return cleaned_tokens
|
||||
return tokens
|
||||
|
||||
def no_token_to_silence(tokens):
|
||||
new_tokens = []
|
||||
silence_token = None
|
||||
for token in tokens:
|
||||
if token.speaker == -2:
|
||||
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
|
||||
new_tokens[-1].end = token.end
|
||||
else:
|
||||
new_tokens.append(token)
|
||||
|
||||
last_end = new_tokens[-1].end if new_tokens else 0.0
|
||||
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
|
||||
if new_tokens and new_tokens[-1].speaker == -2:
|
||||
new_tokens[-1].end = token.start
|
||||
else:
|
||||
silence_token = ASRToken(
|
||||
start=last_end,
|
||||
end=token.start,
|
||||
speaker=-2,
|
||||
)
|
||||
new_tokens.append(silence_token)
|
||||
|
||||
if token.speaker != -2:
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
|
||||
current_time = time() - (beg_loop if beg_loop else 0.0)
|
||||
last_token = tokens[-1]
|
||||
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
||||
if last_token.speaker == -2:
|
||||
last_token.end = current_time
|
||||
else:
|
||||
tokens.append(
|
||||
ASRToken(
|
||||
start=tokens[-1].end,
|
||||
end=current_time,
|
||||
speaker=-2,
|
||||
)
|
||||
)
|
||||
return tokens
|
||||
|
||||
|
||||
def handle_silences(tokens, beg_loop, vac_detected_silence):
|
||||
if not tokens:
|
||||
return []
|
||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||
tokens = no_token_to_silence(tokens)
|
||||
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
|
||||
return tokens
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
########### WHAT IS PRODUCED: ###########
|
||||
|
||||
SPEAKER 1 0:00:04 - 0:00:06
|
||||
Transcription technology has improved so much in the past
|
||||
|
||||
SPEAKER 1 0:00:07 - 0:00:12
|
||||
years. Have you noticed how accurate real-time speech detects is now?
|
||||
|
||||
SPEAKER 2 0:00:12 - 0:00:12
|
||||
Absolutely
|
||||
|
||||
SPEAKER 1 0:00:13 - 0:00:13
|
||||
.
|
||||
|
||||
SPEAKER 2 0:00:14 - 0:00:14
|
||||
I
|
||||
|
||||
SPEAKER 1 0:00:14 - 0:00:17
|
||||
use it all the time for taking notes during meetings.
|
||||
|
||||
SPEAKER 2 0:00:17 - 0:00:17
|
||||
It
|
||||
|
||||
SPEAKER 1 0:00:17 - 0:00:22
|
||||
's amazing how it can recognize different speakers, and even add punctuation.
|
||||
|
||||
SPEAKER 2 0:00:22 - 0:00:22
|
||||
Yeah
|
||||
|
||||
SPEAKER 1 0:00:23 - 0:00:26
|
||||
, but sometimes noise can still cause mistakes.
|
||||
|
||||
SPEAKER 3 0:00:26 - 0:00:27
|
||||
Does
|
||||
|
||||
SPEAKER 1 0:00:27 - 0:00:28
|
||||
this system handle that
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:29
|
||||
?
|
||||
|
||||
SPEAKER 3 0:00:29 - 0:00:29
|
||||
It
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:33
|
||||
does a pretty good job filtering noise, especially with models that use voice activity
|
||||
|
||||
########### WHAT SHOULD BE PRODUCED: ###########
|
||||
|
||||
SPEAKER 1 0:00:04 - 0:00:12
|
||||
Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now?
|
||||
|
||||
SPEAKER 2 0:00:12 - 0:00:22
|
||||
Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation.
|
||||
|
||||
SPEAKER 3 0:00:22 - 0:00:28
|
||||
Yeah, but sometimes noise can still cause mistakes. Does this system handle that well?
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:29
|
||||
It does a pretty good job filtering noise, especially with models that use voice activity
|
||||
@@ -1,257 +0,0 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
CHECK_AROUND = 4
|
||||
DEBUG = False
|
||||
|
||||
def next_punctuation_change(i, tokens):
|
||||
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
|
||||
if tokens[ind].is_punctuation():
|
||||
return ind
|
||||
return None
|
||||
|
||||
def next_speaker_change(i, tokens, speaker):
|
||||
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
|
||||
token = tokens[ind]
|
||||
if token.is_punctuation():
|
||||
break
|
||||
if token.speaker != speaker:
|
||||
return ind, token.speaker
|
||||
return None, speaker
|
||||
|
||||
def new_line(
|
||||
token,
|
||||
):
|
||||
return Line(
|
||||
speaker = token.corrected_speaker,
|
||||
text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
|
||||
start = token.start,
|
||||
end = token.end,
|
||||
detected_language=token.detected_language
|
||||
)
|
||||
|
||||
def append_token_to_last_line(lines, sep, token):
|
||||
if not lines:
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
if token.text:
|
||||
lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
|
||||
lines[-1].end = token.end
|
||||
if not lines[-1].detected_language and token.detected_language:
|
||||
lines[-1].detected_language = token.detected_language
|
||||
|
||||
def extract_number(s) -> int:
|
||||
"""Extract number from speaker string (for diart compatibility)."""
|
||||
if isinstance(s, int):
|
||||
return s
|
||||
m = re.search(r'\d+', str(s))
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]:
|
||||
"""Concatenate consecutive segments from the same speaker."""
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
# Get speaker number from first segment
|
||||
first_speaker = extract_number(segments[0].speaker)
|
||||
segments_concatenated = [{"speaker": first_speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
|
||||
|
||||
for segment in segments[1:]:
|
||||
speaker = extract_number(segment.speaker) + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
|
||||
return segments_concatenated
|
||||
|
||||
def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list:
|
||||
"""Assign speakers to tokens with punctuation-aware boundary adjustment."""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
segments_concatenated = concatenate_speakers(segments)
|
||||
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
# Ensure non-overlapping tokens
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
# Assign speakers based on adjusted segments
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
|
||||
return tokens
|
||||
|
||||
def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens with timing information
|
||||
segments: List of speaker segments
|
||||
use_punctuation_split: Whether to use punctuation for boundary refinement
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
if not segments or not tokens:
|
||||
logger.debug("No segments or tokens available for speaker assignment")
|
||||
return tokens
|
||||
|
||||
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||
|
||||
if not use_punctuation_split:
|
||||
# Simple overlap-based assignment
|
||||
for token in tokens:
|
||||
token.speaker = -1 # Default to no speaker
|
||||
for segment in segments:
|
||||
# Check for timing overlap
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
speaker_num = extract_number(segment.speaker)
|
||||
token.speaker = speaker_num + 1 # Convert to 1-based indexing
|
||||
break
|
||||
else:
|
||||
# Use punctuation-aware assignment
|
||||
tokens = add_speaker_to_tokens_with_punctuation(segments, tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def format_output(state, silence, args, sep):
|
||||
diarization = args.diarization
|
||||
disable_punctuation_split = args.disable_punctuation_split
|
||||
tokens = state.tokens
|
||||
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
last_validated_token = state.last_validated_token
|
||||
|
||||
last_speaker = abs(state.last_speaker)
|
||||
undiarized_text = []
|
||||
tokens = handle_silences(tokens, state.beg_loop, silence)
|
||||
|
||||
# Assign speakers to tokens based on segments stored in state
|
||||
if False and diarization and state.diarization_segments:
|
||||
use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False
|
||||
tokens = assign_speakers_to_tokens(tokens, state.diarization_segments, use_punctuation_split=use_punctuation_split)
|
||||
for i in range(last_validated_token, len(tokens)):
|
||||
token = tokens[i]
|
||||
speaker = int(token.speaker)
|
||||
token.corrected_speaker = speaker
|
||||
if True or not diarization:
|
||||
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
token.corrected_speaker = 1
|
||||
token.validated_speaker = True
|
||||
else:
|
||||
if token.speaker == -1:
|
||||
undiarized_text.append(token.text)
|
||||
elif token.is_punctuation():
|
||||
state.last_punctuation_index = i
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = True
|
||||
elif state.last_punctuation_index == i-1:
|
||||
if token.speaker != last_speaker:
|
||||
token.corrected_speaker = token.speaker
|
||||
token.validated_speaker = True
|
||||
# perfect, diarization perfectly aligned
|
||||
else:
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
||||
token.corrected_speaker = new_speaker
|
||||
token.validated_speaker = True
|
||||
elif speaker != last_speaker:
|
||||
if not (speaker == -2 or last_speaker == -2):
|
||||
if next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = True
|
||||
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
|
||||
if not disable_punctuation_split:
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = False
|
||||
if token.validated_speaker:
|
||||
state.last_validated_token = i
|
||||
state.last_speaker = token.corrected_speaker
|
||||
|
||||
last_speaker = 1
|
||||
|
||||
lines = []
|
||||
for token in tokens:
|
||||
if token.corrected_speaker != -1:
|
||||
if int(token.corrected_speaker) != int(last_speaker):
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
append_token_to_last_line(lines, sep, token)
|
||||
|
||||
last_speaker = token.corrected_speaker
|
||||
|
||||
if lines:
|
||||
unassigned_translated_segments = []
|
||||
for ts in translation_validated_segments:
|
||||
assigned = False
|
||||
for line in lines:
|
||||
if ts and ts.overlaps_with(line):
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + ' '
|
||||
assigned = True
|
||||
break
|
||||
else:
|
||||
ts0, ts1 = ts.approximate_cut_at(line.end)
|
||||
if ts0 and line.overlaps_with(ts0):
|
||||
line.translation += ts0.text + ' '
|
||||
if ts1:
|
||||
unassigned_translated_segments.append(ts1)
|
||||
assigned = True
|
||||
break
|
||||
if not assigned:
|
||||
unassigned_translated_segments.append(ts)
|
||||
|
||||
if unassigned_translated_segments:
|
||||
for line in lines:
|
||||
remaining_segments = []
|
||||
for ts in unassigned_translated_segments:
|
||||
if ts and ts.overlaps_with(line):
|
||||
line.translation += ts.text + ' '
|
||||
else:
|
||||
remaining_segments.append(ts)
|
||||
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
||||
|
||||
if state.buffer_transcription and lines:
|
||||
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
||||
|
||||
return lines, undiarized_text
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Any, List
|
||||
from datetime import timedelta
|
||||
from typing import Union
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
@@ -39,7 +40,9 @@ class TimedText(Timed):
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.text)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return str(self.text)
|
||||
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
@@ -53,6 +56,10 @@ class ASRToken(TimedText):
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
||||
|
||||
def is_silence(self):
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
pass
|
||||
@@ -134,6 +141,46 @@ class Silence():
|
||||
return None
|
||||
self.duration = self.end - self.start
|
||||
|
||||
def is_silence(self):
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class Segment():
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
text: Optional[str]
|
||||
speaker: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def from_tokens(
|
||||
cls,
|
||||
tokens: List[Union[ASRToken, Silence]],
|
||||
is_silence=False
|
||||
) -> "Segment":
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
start_token = tokens[0]
|
||||
end_token = tokens[-1]
|
||||
if is_silence:
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
end=end_token.end,
|
||||
text=None,
|
||||
speaker = -2
|
||||
)
|
||||
else:
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
end=end_token.end,
|
||||
text=''.join(token.text for token in tokens),
|
||||
speaker = -1
|
||||
)
|
||||
def is_silence(self):
|
||||
return self.speaker == -2
|
||||
|
||||
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
@@ -158,6 +205,13 @@ class Line(TimedText):
|
||||
self.speaker = 1
|
||||
return self
|
||||
|
||||
def build_from_segment(self, segment: Segment):
|
||||
self.text = segment.text
|
||||
self.start = segment.start
|
||||
self.end = segment.end
|
||||
self.speaker = segment.speaker
|
||||
return self
|
||||
|
||||
def is_silent(self) -> bool:
|
||||
return self.speaker == -2
|
||||
|
||||
@@ -193,47 +247,6 @@ class FrontData():
|
||||
_dict['error'] = self.error
|
||||
return _dict
|
||||
|
||||
@dataclass
|
||||
class PunctuationSegment():
|
||||
"""Represents a segment of text between punctuation marks."""
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
token_index_start: int
|
||||
token_index_end: int
|
||||
punctuation_token_index: int
|
||||
punctuation_token: ASRToken
|
||||
|
||||
@classmethod
|
||||
def from_token_range(
|
||||
cls,
|
||||
tokens: List[ASRToken],
|
||||
token_index_start: int,
|
||||
token_index_end: int,
|
||||
punctuation_token_index: int
|
||||
) -> "PunctuationSegment":
|
||||
"""Create a PunctuationSegment from a range of tokens ending at a punctuation mark."""
|
||||
if not tokens or token_index_start < 0 or token_index_end >= len(tokens):
|
||||
raise ValueError("Invalid token indices")
|
||||
|
||||
start_token = tokens[token_index_start]
|
||||
end_token = tokens[token_index_end]
|
||||
punctuation_token = tokens[punctuation_token_index]
|
||||
|
||||
# Build text from tokens in the segment
|
||||
segment_tokens = tokens[token_index_start:token_index_end + 1]
|
||||
text = ''.join(token.text for token in segment_tokens)
|
||||
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
end=end_token.end,
|
||||
text=text,
|
||||
token_index_start=token_index_start,
|
||||
token_index_end=token_index_end,
|
||||
punctuation_token_index=punctuation_token_index,
|
||||
punctuation_token=punctuation_token
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChangeSpeaker:
|
||||
speaker: int
|
||||
@@ -260,4 +273,5 @@ class StateLight():
|
||||
new_tokens: list = field(default_factory=list)
|
||||
new_translation: list = field(default_factory=list)
|
||||
new_diarization: list = field(default_factory=list)
|
||||
new_tokens_buffer: list = field(default_factory=list) #only when local agreement
|
||||
new_tokens_buffer: list = field(default_factory=list) #only when local agreement
|
||||
new_translation_buffer: str = ''
|
||||
179
whisperlivekit/tokens_alignment.py
Normal file
179
whisperlivekit/tokens_alignment.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence, TimedText, Segment
|
||||
|
||||
|
||||
class TokensAlignment:
|
||||
|
||||
def __init__(self, state, args, sep):
|
||||
self.state = state
|
||||
self.diarization = args.diarization
|
||||
self._tokens_index = 0
|
||||
self._diarization_index = 0
|
||||
self._translation_index = 0
|
||||
|
||||
self.all_tokens : list[ASRToken] = []
|
||||
self.all_diarization_segments: list[SpeakerSegment] = []
|
||||
self.all_translation_segments = []
|
||||
|
||||
self.new_tokens : list[ASRToken] = []
|
||||
self.new_diarization: list[SpeakerSegment] = []
|
||||
self.new_translation = []
|
||||
self.new_translation_buffer = TimedText()
|
||||
self.new_tokens_buffer = []
|
||||
self.sep = sep if sep is not None else ' '
|
||||
self.beg_loop = None
|
||||
|
||||
def update(self):
|
||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||
self.new_diarization, self.state.new_diarization = self.state.new_diarization, []
|
||||
self.new_translation, self.state.new_translation = self.state.new_translation, []
|
||||
self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, []
|
||||
|
||||
self.all_tokens.extend(self.new_tokens)
|
||||
self.all_diarization_segments.extend(self.new_diarization)
|
||||
# self.all_translation_segments.extend(self.new_translation) #future
|
||||
self.all_translation_segments = self.new_translation if self.new_translation != [] else self.all_translation_segments
|
||||
self.new_translation_buffer = self.state.new_translation_buffer if self.new_translation else self.new_translation_buffer
|
||||
self.new_translation_buffer = self.new_translation_buffer if type(self.new_translation_buffer) == str else self.new_translation_buffer.text
|
||||
|
||||
def add_translation(self, line : Line):
|
||||
|
||||
for ts in self.all_translation_segments:
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + self.sep
|
||||
elif line.translation:
|
||||
break
|
||||
|
||||
|
||||
def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None):
|
||||
segments = []
|
||||
segment_start_idx = 0
|
||||
for i, token in enumerate(self.all_tokens):
|
||||
if token.is_silence():
|
||||
previous_segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx: i],
|
||||
)
|
||||
if previous_segment:
|
||||
segments.append(previous_segment)
|
||||
segment = Segment.from_tokens(
|
||||
tokens=[token],
|
||||
is_silence=True
|
||||
)
|
||||
segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
else:
|
||||
if token.is_punctuation():
|
||||
segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx: i+1],
|
||||
)
|
||||
segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
|
||||
final_segment = Segment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx:],
|
||||
)
|
||||
if final_segment:
|
||||
segments.append(final_segment)
|
||||
return segments
|
||||
|
||||
|
||||
def concatenate_diar_segments(self):
|
||||
if not self.all_diarization_segments:
|
||||
return []
|
||||
merged = [self.all_diarization_segments[0]]
|
||||
for segment in self.all_diarization_segments[1:]:
|
||||
if segment.speaker == merged[-1].speaker:
|
||||
merged[-1].end = segment.end
|
||||
else:
|
||||
merged.append(segment)
|
||||
return merged
|
||||
|
||||
|
||||
@staticmethod
|
||||
def intersection_duration(seg1, seg2):
|
||||
start = max(seg1.start, seg2.start)
|
||||
end = min(seg1.end, seg2.end)
|
||||
|
||||
return max(0, end - start)
|
||||
|
||||
def get_lines_diarization(self):
|
||||
"""
|
||||
use compute_punctuations_segments, concatenate_diar_segments, intersection_duration
|
||||
"""
|
||||
diarization_buffer = ''
|
||||
punctuation_segments = self.compute_punctuations_segments()
|
||||
diarization_segments = self.concatenate_diar_segments()
|
||||
for punctuation_segment in punctuation_segments:
|
||||
if not punctuation_segment.is_silence():
|
||||
if diarization_segments and punctuation_segment.start >= diarization_segments[-1].end:
|
||||
diarization_buffer += punctuation_segment.text
|
||||
else:
|
||||
max_overlap = 0.0
|
||||
max_overlap_speaker = 1
|
||||
for diarization_segment in diarization_segments:
|
||||
intersec = self.intersection_duration(punctuation_segment, diarization_segment)
|
||||
if intersec > max_overlap:
|
||||
max_overlap = intersec
|
||||
max_overlap_speaker = diarization_segment.speaker + 1
|
||||
punctuation_segment.speaker = max_overlap_speaker
|
||||
|
||||
lines = []
|
||||
if punctuation_segments:
|
||||
lines = [Line().build_from_segment(punctuation_segments[0])]
|
||||
for segment in punctuation_segments[1:]:
|
||||
if segment.speaker == lines[-1].speaker:
|
||||
if lines[-1].text:
|
||||
lines[-1].text += segment.text
|
||||
lines[-1].end = segment.end
|
||||
else:
|
||||
lines.append(Line().build_from_segment(segment))
|
||||
|
||||
return lines, diarization_buffer
|
||||
|
||||
|
||||
def get_lines(
|
||||
self,
|
||||
diarization=False,
|
||||
translation=False,
|
||||
current_silence=None
|
||||
):
|
||||
"""
|
||||
In the case without diarization
|
||||
"""
|
||||
if diarization:
|
||||
lines, diarization_buffer = self.get_lines_diarization()
|
||||
else:
|
||||
diarization_buffer = ''
|
||||
lines = []
|
||||
current_line_tokens = []
|
||||
for token in self.all_tokens:
|
||||
if token.is_silence():
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
current_line_tokens = []
|
||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = token.start,
|
||||
end = end_silence
|
||||
))
|
||||
else:
|
||||
current_line_tokens.append(token)
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = current_silence.start,
|
||||
end = end_silence
|
||||
))
|
||||
if translation:
|
||||
[self.add_translation(line) for line in lines if not type(line) == Silence]
|
||||
return lines, diarization_buffer, self.new_translation_buffer
|
||||
@@ -1,60 +0,0 @@
|
||||
from typing import Sequence, Callable, Any, Optional, Dict
|
||||
|
||||
def _detect_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
|
||||
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
|
||||
max_tail: int = 300, # search window from the end for speed
|
||||
prefer: str = "longest", # "longest" coverage or "smallest" block
|
||||
) -> Optional[Dict]:
|
||||
vals = [key(x) for x in seq][-max_tail:]
|
||||
n = len(vals)
|
||||
best = None
|
||||
|
||||
# try every possible block length
|
||||
for b in range(min_block, n // 2 + 1):
|
||||
block = vals[-b:]
|
||||
# count how many times this block repeats contiguously at the very end
|
||||
count, i = 0, n
|
||||
while i - b >= 0 and vals[i - b:i] == block:
|
||||
count += 1
|
||||
i -= b
|
||||
|
||||
if count >= 2:
|
||||
cand = {
|
||||
"block_size": b,
|
||||
"count": count,
|
||||
"start_index": len(seq) - count * b, # in original seq
|
||||
"end_index": len(seq),
|
||||
}
|
||||
if (best is None or
|
||||
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
|
||||
(prefer == "smallest" and b < best["block_size"])):
|
||||
best = cand
|
||||
return best
|
||||
|
||||
def trim_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x,
|
||||
min_block: int = 1,
|
||||
max_tail: int = 300,
|
||||
prefer: str = "longest",
|
||||
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
|
||||
):
|
||||
"""
|
||||
Returns a new sequence with repeated tail trimmed.
|
||||
keep=1 -> keep a single copy of the repeated block.
|
||||
keep=0 -> remove all copies of the repeated block.
|
||||
"""
|
||||
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
|
||||
if not rep:
|
||||
return seq, False # nothing to trim
|
||||
|
||||
b, c = rep["block_size"], rep["count"]
|
||||
if keep < 0:
|
||||
keep = 0
|
||||
if keep >= c:
|
||||
return seq, False # nothing to trim (already <= keep copies)
|
||||
# new length = total - (copies_to_remove * block_size)
|
||||
new_len = len(seq) - (c - keep) * b
|
||||
return seq[:new_len], True
|
||||
Reference in New Issue
Block a user