mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
internal rework 4
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
from whisperlivekit.timed_objects import Line, SilentLine, format_time, SpeakerSegment, Silence
|
||||
from whisperlivekit.timed_objects import PunctuationSegment
|
||||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence
|
||||
from whisperlivekit.timed_objects import PunctuationSegment
|
||||
|
||||
ALIGNMENT_TIME_TOLERANCE = 0.2 # seconds
|
||||
|
||||
|
||||
class TokensAlignment:
|
||||
@@ -12,15 +16,16 @@ class TokensAlignment:
|
||||
self._diarization_index = 0
|
||||
self._translation_index = 0
|
||||
|
||||
self.all_tokens = []
|
||||
self.all_diarization_segments = []
|
||||
self.all_tokens : list[ASRToken] = []
|
||||
self.all_diarization_segments: list[SpeakerSegment] = []
|
||||
self.all_translation_segments = []
|
||||
|
||||
self.new_tokens = []
|
||||
self.new_tokens : list[ASRToken] = []
|
||||
self.new_diarization: list[SpeakerSegment] = []
|
||||
self.new_translation = []
|
||||
self.new_diarization = []
|
||||
self.new_tokens_buffer = []
|
||||
self.sep = ' '
|
||||
self.sep = sep if sep is not None else ' '
|
||||
self.beg_loop = None
|
||||
|
||||
def update(self):
|
||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||
@@ -32,7 +37,10 @@ class TokensAlignment:
|
||||
self.all_diarization_segments.extend(self.new_diarization)
|
||||
self.all_translation_segments.extend(self.new_translation)
|
||||
|
||||
def create_lines_from_tokens(self, current_silence, beg_loop):
|
||||
def get_lines(self, current_silence):
|
||||
"""
|
||||
In the case without diarization
|
||||
"""
|
||||
lines = []
|
||||
current_line_tokens = []
|
||||
for token in self.all_tokens:
|
||||
@@ -40,7 +48,7 @@ class TokensAlignment:
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
current_line_tokens = []
|
||||
end_silence = token.end if token.has_ended else time() - beg_loop
|
||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
@@ -53,7 +61,7 @@ class TokensAlignment:
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - beg_loop
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
@@ -64,22 +72,104 @@ class TokensAlignment:
|
||||
|
||||
return lines
|
||||
|
||||
def align_tokens(self):
|
||||
if not self.diarization:
|
||||
pass
|
||||
# return self.all_tokens
|
||||
|
||||
def compute_punctuations_segments(self):
|
||||
def _get_asr_tokens(self) -> list[ASRToken]:
|
||||
return [token for token in self.all_tokens if isinstance(token, ASRToken)]
|
||||
|
||||
def _tokens_to_text(self, tokens: list[ASRToken]) -> str:
|
||||
return ''.join(token.text for token in tokens)
|
||||
|
||||
def _extract_detected_language(self, tokens: list[ASRToken]):
|
||||
for token in tokens:
|
||||
if getattr(token, 'detected_language', None):
|
||||
return token.detected_language
|
||||
return None
|
||||
|
||||
def _speaker_display_id(self, raw_speaker) -> int:
|
||||
if isinstance(raw_speaker, int):
|
||||
speaker_index = raw_speaker
|
||||
else:
|
||||
digits = ''.join(ch for ch in str(raw_speaker) if ch.isdigit())
|
||||
speaker_index = int(digits) if digits else 0
|
||||
return speaker_index + 1 if speaker_index >= 0 else 0
|
||||
|
||||
def _line_from_tokens(self, tokens: list[ASRToken], speaker: int) -> Line:
    """Build a ``Line`` from ``tokens`` and tag it with ``speaker``.

    The line is created with ``Line().build_from_tokens(tokens)``. Its
    ``detected_language`` is set only when one of the tokens carries a
    truthy value; otherwise whatever ``build_from_tokens`` produced is kept.
    """
    line = Line().build_from_tokens(tokens)
    line.speaker = speaker
    detected_language = self._extract_detected_language(tokens)
    if detected_language:
        line.detected_language = detected_language
    return line
|
||||
|
||||
def _find_initial_diar_index(self, diar_segments: list[SpeakerSegment], start_time: float) -> int:
|
||||
for idx, segment in enumerate(diar_segments):
|
||||
if segment.end + ALIGNMENT_TIME_TOLERANCE >= start_time:
|
||||
return idx
|
||||
return len(diar_segments)
|
||||
|
||||
def _find_speaker_for_token(self, token: ASRToken, diar_segments: list[SpeakerSegment], diar_idx: int):
|
||||
if not diar_segments:
|
||||
return None, diar_idx
|
||||
idx = min(diar_idx, len(diar_segments) - 1)
|
||||
midpoint = (token.start + token.end) / 2 if token.end is not None else token.start
|
||||
|
||||
while idx < len(diar_segments) and diar_segments[idx].end + ALIGNMENT_TIME_TOLERANCE < midpoint:
|
||||
idx += 1
|
||||
|
||||
candidate_indices = []
|
||||
if idx < len(diar_segments):
|
||||
candidate_indices.append(idx)
|
||||
if idx > 0:
|
||||
candidate_indices.append(idx - 1)
|
||||
|
||||
for candidate_idx in candidate_indices:
|
||||
segment = diar_segments[candidate_idx]
|
||||
seg_start = (segment.start or 0) - ALIGNMENT_TIME_TOLERANCE
|
||||
seg_end = (segment.end or 0) + ALIGNMENT_TIME_TOLERANCE
|
||||
if seg_start <= midpoint <= seg_end:
|
||||
return segment.speaker, candidate_idx
|
||||
|
||||
return None, idx
|
||||
|
||||
def _build_lines_for_tokens(self, tokens: list[ASRToken], diar_segments: list[SpeakerSegment], diar_idx: int):
|
||||
if not tokens:
|
||||
return [], diar_idx
|
||||
|
||||
segment_lines: list[Line] = []
|
||||
current_tokens: list[ASRToken] = []
|
||||
current_speaker = None
|
||||
pointer = diar_idx
|
||||
|
||||
for token in tokens:
|
||||
speaker_raw, pointer = self._find_speaker_for_token(token, diar_segments, pointer)
|
||||
if speaker_raw is None:
|
||||
return [], diar_idx
|
||||
speaker = self._speaker_display_id(speaker_raw)
|
||||
if current_speaker is None or current_speaker != speaker:
|
||||
if current_tokens:
|
||||
segment_lines.append(self._line_from_tokens(current_tokens, current_speaker))
|
||||
current_tokens = [token]
|
||||
current_speaker = speaker
|
||||
else:
|
||||
current_tokens.append(token)
|
||||
|
||||
if current_tokens:
|
||||
segment_lines.append(self._line_from_tokens(current_tokens, current_speaker))
|
||||
|
||||
return segment_lines, pointer
|
||||
|
||||
def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None):
|
||||
"""Compute segments of text between punctuation marks.
|
||||
|
||||
Returns a list of PunctuationSegment objects, each representing
|
||||
the text from the start (or previous punctuation) to the current punctuation mark.
|
||||
"""
|
||||
|
||||
if not self.all_tokens:
|
||||
tokens = tokens if tokens is not None else self._get_asr_tokens()
|
||||
if not tokens:
|
||||
return []
|
||||
punctuation_indices = [
|
||||
i for i, token in enumerate(self.all_tokens)
|
||||
i for i, token in enumerate[ASRToken](tokens)
|
||||
if token.is_punctuation()
|
||||
]
|
||||
if not punctuation_indices:
|
||||
@@ -91,7 +181,7 @@ class TokensAlignment:
|
||||
end_idx = punct_idx
|
||||
if start_idx <= end_idx:
|
||||
segment = PunctuationSegment.from_token_range(
|
||||
tokens=self.all_tokens,
|
||||
tokens=tokens,
|
||||
token_index_start=start_idx,
|
||||
token_index_end=end_idx,
|
||||
punctuation_token_index=punct_idx
|
||||
@@ -109,4 +199,42 @@ class TokensAlignment:
|
||||
merged[-1].end = segment.end
|
||||
else:
|
||||
merged.append(segment)
|
||||
return merged
|
||||
return merged
|
||||
|
||||
def get_lines(self, diarization=False, translation=False):
    """
    Align diarization speaker segments with punctuation-delimited transcription
    segments (see docs/alignement_principles.md).

    Punctuation segments are consumed in order. Consumption stops at the
    first segment that has no end, ends after the last diarization segment,
    or for which no line could be built. Tokens after the last consumed
    segment are returned as plain text.

    Args:
        diarization: accepted for the call site; not read by this method.
        translation: accepted for the call site; not read by this method.

    Returns:
        ``(lines, buffer_diarization, buffer_translation)``. Three values are
        returned because the call site unpacks three. The translation buffer
        is not computed here and is always ``''``.
    """
    # NOTE(review): translation buffering is not implemented in this method.
    buffer_translation = ''

    tokens = self._get_asr_tokens()
    if not tokens:
        return [], '', buffer_translation

    punctuation_segments = self.compute_punctuations_segments(tokens=tokens)
    diar_segments = self.concatenate_diar_segments()

    if not punctuation_segments or not diar_segments:
        return [], self._tokens_to_text(tokens), buffer_translation

    max_diar_end = diar_segments[-1].end
    if max_diar_end is None:
        return [], self._tokens_to_text(tokens), buffer_translation

    lines: list[Line] = []
    last_consumed_index = -1
    diar_idx = self._find_initial_diar_index(diar_segments, tokens[0].start or 0)

    for segment in punctuation_segments:
        # Only segments fully covered by the diarization output are attributed.
        if segment.end is None or segment.end > max_diar_end:
            break
        slice_tokens = tokens[segment.token_index_start:segment.token_index_end + 1]
        segment_lines, diar_idx = self._build_lines_for_tokens(slice_tokens, diar_segments, diar_idx)
        if not segment_lines:
            break
        lines.extend(segment_lines)
        last_consumed_index = segment.token_index_end

    # Slicing past the end already yields an empty list, so no guard is needed.
    buffer_tokens = tokens[last_consumed_index + 1:]
    buffer_diarization = self._tokens_to_text(buffer_tokens)

    return lines, buffer_diarization, buffer_translation
|
||||
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
MILENCE_DURATION = 3
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
def cut_at(cumulative_pcm, cut_sec):
|
||||
cumulative_len = 0
|
||||
@@ -165,7 +165,7 @@ class AudioProcessor:
|
||||
self.current_silence.is_starting=False
|
||||
self.current_silence.has_ended=True
|
||||
self.current_silence.compute_duration()
|
||||
if self.current_silence.duration > MILENCE_DURATION:
|
||||
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
|
||||
self.state_light.new_tokens.append(self.current_silence)
|
||||
await self._push_silence_event()
|
||||
self.current_silence = None
|
||||
@@ -410,57 +410,32 @@ class AudioProcessor:
|
||||
continue
|
||||
|
||||
self.tokens_alignment.update()
|
||||
lines = self.tokens_alignment.create_lines_from_tokens(self.current_silence, self.beg_loop)
|
||||
undiarized_text = ''
|
||||
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
|
||||
diarization=self.args.diarization,
|
||||
translation=self.args.translation
|
||||
)
|
||||
state = await self.get_current_state()
|
||||
# self.tokens_alignment.compute_punctuations_segments()
|
||||
# lines, undiarized_text = format_output(
|
||||
# state,
|
||||
# self.current_silence,
|
||||
# args = self.args,
|
||||
# sep=self.sep
|
||||
# )
|
||||
if lines and lines[-1].speaker == -2:
|
||||
buffer_transcription = Transcript()
|
||||
else:
|
||||
buffer_transcription = state.buffer_transcription
|
||||
|
||||
buffer_diarization = ''
|
||||
if undiarized_text:
|
||||
buffer_diarization = self.sep.join(undiarized_text)
|
||||
|
||||
async with self.lock:
|
||||
self.state.end_attributed_speaker = state.end_attributed_speaker
|
||||
|
||||
buffer_translation_text = ''
|
||||
if state.buffer_translation:
|
||||
raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
|
||||
if raw_buffer_translation:
|
||||
buffer_translation_text = raw_buffer_translation.strip()
|
||||
|
||||
buffer_transcription_text = ''
|
||||
buffer_diarization_text = ''
|
||||
|
||||
response_status = "active_transcription"
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
if not lines and not buffer_transcription_text and not buffer_diarization_text:
|
||||
response_status = "no_audio_detected"
|
||||
lines = []
|
||||
elif not lines:
|
||||
lines = [Line(
|
||||
speaker=1,
|
||||
start=state.end_buffer,
|
||||
end=state.end_buffer
|
||||
)]
|
||||
|
||||
|
||||
response = FrontData(
|
||||
status=response_status,
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription.text.strip(),
|
||||
buffer_diarization=buffer_diarization,
|
||||
buffer_transcription=buffer_transcription_text,
|
||||
buffer_diarization=buffer_diarization_text,
|
||||
buffer_translation=buffer_translation_text,
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
|
||||
should_push = (response != self.last_response_content)
|
||||
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
||||
if should_push:
|
||||
yield response
|
||||
self.last_response_content = response
|
||||
|
||||
@@ -582,6 +557,7 @@ class AudioProcessor:
|
||||
if not self.beg_loop:
|
||||
self.beg_loop = time()
|
||||
self.current_silence = Silence(start=0.0, is_starting=True)
|
||||
self.tokens_alignment.beg_loop = self.beg_loop
|
||||
|
||||
if not message:
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
|
||||
@@ -162,8 +162,10 @@ class Line(TimedText):
|
||||
return self.speaker == -2
|
||||
|
||||
class SilentLine(Line):
    """A ``Line`` that represents silence: speaker ``-2`` and empty text.

    The values are assigned after ``Line.__init__`` has run, so they win over
    anything the parent initializer (or the caller's arguments) set.
    NOTE(review): presumably plain class-level overrides were shadowed by the
    instance attributes set in ``Line.__init__`` — confirm against ``Line``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # -2 is the speaker value checked by `return self.speaker == -2` in Line.
        self.speaker = -2
        self.text = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -192,8 +194,10 @@ class FrontData():
|
||||
return _dict
|
||||
|
||||
@dataclass
|
||||
class PunctuationSegment(TimedText):
|
||||
class PunctuationSegment():
|
||||
"""Represents a segment of text between punctuation marks."""
|
||||
start: Optional[float]
|
||||
end: Optional[float]
|
||||
token_index_start: int
|
||||
token_index_end: int
|
||||
punctuation_token_index: int
|
||||
|
||||
Reference in New Issue
Block a user