internal rework 3

This commit is contained in:
Quentin Fuxa
2025-11-20 22:28:30 +01:00
parent b7c1cc77cc
commit 270faf2069
3 changed files with 93 additions and 35 deletions

View File

@@ -1,4 +1,5 @@
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment, Silence
from whisperlivekit.timed_objects import Line, SilentLine, format_time, SpeakerSegment, Silence
from whisperlivekit.timed_objects import PunctuationSegment
from time import time
@@ -40,12 +41,10 @@ class TokensAlignment:
lines.append(Line().build_from_tokens(current_line_tokens))
current_line_tokens = []
end_silence = token.end if token.has_ended else time() - beg_loop
if lines and lines[-1].speaker == -2:
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(Line(
speaker = -2,
text = '',
lines.append(SilentLine(
start = token.start,
end = end_silence
))
@@ -55,12 +54,10 @@ class TokensAlignment:
lines.append(Line().build_from_tokens(current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - beg_loop
if lines and lines[-1].speaker == -2:
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
lines.append(Line(
speaker = -2,
text = '',
lines.append(SilentLine(
start = current_silence.start,
end = end_silence
))
@@ -73,28 +70,43 @@ class TokensAlignment:
# return self.all_tokens
def compute_punctuations_segments(self):
    """Compute segments of text between punctuation marks.

    Returns a list of PunctuationSegment objects, each representing
    the text from the start (or previous punctuation) to the current
    punctuation mark. Returns [] when there are no tokens, or when no
    token is a punctuation mark.
    """
    if not self.all_tokens:
        return []
    punctuation_indices = [
        i for i, token in enumerate(self.all_tokens)
        if token.is_punctuation()
    ]
    if not punctuation_indices:
        return []
    segments = []
    for i, punct_idx in enumerate(punctuation_indices):
        # A segment starts right after the previous punctuation mark
        # (or at token 0 for the first segment) and ends at, and
        # includes, the current punctuation token.
        start_idx = punctuation_indices[i - 1] + 1 if i > 0 else 0
        end_idx = punct_idx
        # Guard against two adjacent punctuation tokens producing an
        # empty (inverted) range.
        if start_idx <= end_idx:
            segment = PunctuationSegment.from_token_range(
                tokens=self.all_tokens,
                token_index_start=start_idx,
                token_index_end=end_idx,
                punctuation_token_index=punct_idx
            )
            segments.append(segment)
    return segments
def concatenate_diar_segments(self):
    """Merge consecutive diarization segments that share a speaker.

    Walks ``self.all_diarization_segments`` in order; whenever a segment
    has the same ``speaker`` as the previous merged segment, the previous
    segment's ``end`` is extended in place instead of appending a new
    entry. Returns [] when there are no segments.
    """
    # Removed dead local: the original assigned
    # ``diarization_segments = self.state.diarization_segments`` but
    # never used it — this method operates on all_diarization_segments.
    if not self.all_diarization_segments:
        return []
    merged = [self.all_diarization_segments[0]]
    for segment in self.all_diarization_segments[1:]:
        if segment.speaker == merged[-1].speaker:
            # Same speaker continues: extend the running segment.
            merged[-1].end = segment.end
        else:
            merged.append(segment)
    return merged

View File

@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
# Minimum silence duration (presumably seconds — TODO confirm units against
# compute_duration) before a silence token is surfaced downstream.
# NOTE(review): "MILENCE" looks like a typo for MIN_SILENCE_DURATION — confirm.
MILENCE_DURATION = 3
def cut_at(cumulative_pcm, cut_sec):
cumulative_len = 0
@@ -164,7 +165,8 @@ class AudioProcessor:
self.current_silence.is_starting=False
self.current_silence.has_ended=True
self.current_silence.compute_duration()
self.state_light.new_tokens.append(self.current_silence)
if self.current_silence.duration > MILENCE_DURATION:
self.state_light.new_tokens.append(self.current_silence)
await self._push_silence_event()
self.current_silence = None
@@ -365,7 +367,6 @@ class AudioProcessor:
self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize()
self.state_light.new_diarization = diarization_segments
self.state_light.new_diarization_index += 1
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")

View File

@@ -158,7 +158,13 @@ class Line(TimedText):
self.speaker = 1
return self
def is_silent(self) -> bool:
    """Return True when this line marks silence (sentinel speaker id -2)."""
    silence_speaker_id = -2
    return self.speaker == silence_speaker_id
class SilentLine(Line):
    # Sentinel subclass marking a stretch of silence rather than spoken text.
    # NOTE(review): if Line is a dataclass (its siblings here are), the
    # generated __init__ inherited from Line will assign instance attributes
    # from Line's own field defaults, shadowing these class attributes —
    # confirm that instances really end up with speaker == -2 (is_silent()
    # depends on it), or redeclare these as annotated dataclass fields.
    speaker = -2
    text = ''
@dataclass
class FrontData():
@@ -185,6 +191,45 @@ class FrontData():
_dict['error'] = self.error
return _dict
@dataclass
class PunctuationSegment(TimedText):
    """Represents a segment of text between punctuation marks."""
    token_index_start: int       # index of first token in the segment
    token_index_end: int         # index of last token (inclusive)
    punctuation_token_index: int # index of the closing punctuation token
    punctuation_token: ASRToken  # the punctuation token itself

    @classmethod
    def from_token_range(
        cls,
        tokens: List[ASRToken],
        token_index_start: int,
        token_index_end: int,
        punctuation_token_index: int
    ) -> "PunctuationSegment":
        """Create a PunctuationSegment from a range of tokens ending at a punctuation mark.

        Args:
            tokens: full token list the indices refer to.
            token_index_start: index of the first token in the segment.
            token_index_end: inclusive index of the last token in the segment.
            punctuation_token_index: index of the punctuation token closing the segment.

        Raises:
            ValueError: if ``tokens`` is empty or any index is out of range.
        """
        if (
            not tokens
            or token_index_start < 0
            or token_index_end >= len(tokens)
            # Fix: previously unchecked — an out-of-range punctuation index
            # raised a bare IndexError below instead of the documented ValueError.
            or not 0 <= punctuation_token_index < len(tokens)
        ):
            raise ValueError("Invalid token indices")
        start_token = tokens[token_index_start]
        end_token = tokens[token_index_end]
        punctuation_token = tokens[punctuation_token_index]
        # Build text from tokens in the segment (inclusive slice).
        segment_tokens = tokens[token_index_start:token_index_end + 1]
        text = ''.join(token.text for token in segment_tokens)
        return cls(
            start=start_token.start,
            end=end_token.end,
            text=text,
            token_index_start=token_index_start,
            token_index_end=token_index_end,
            punctuation_token_index=punctuation_token_index,
            punctuation_token=punctuation_token
        )
@dataclass
class ChangeSpeaker:
speaker: int