def compute_punctuations_segments(self):
    """Split ``self.all_tokens`` into segments that each end on a punctuation token.

    Each segment runs from the token right after the previous punctuation
    token (or from index 0 for the first segment) up to and including the
    current punctuation token. Tokens that follow the last punctuation token
    do not belong to any segment.

    Returns:
        A list of ``PunctuationSegment`` objects in token order. The list is
        empty when there are no tokens or none of them is punctuation.
    """
    tokens = self.all_tokens
    if not tokens:
        return []

    segments = []
    segment_start = 0
    for punct_idx, token in enumerate(tokens):
        if not token.is_punctuation():
            continue
        segments.append(
            PunctuationSegment.from_token_range(
                tokens=tokens,
                token_index_start=segment_start,
                token_index_end=punct_idx,
                punctuation_token_index=punct_idx,
            )
        )
        # The next segment starts right after this punctuation token.
        segment_start = punct_idx + 1
    return segments


def concatenate_diar_segments(self):
    """Merge consecutive diarization segments spoken by the same speaker.

    Each run of adjacent segments in ``self.all_diarization_segments`` that
    share a ``speaker`` is collapsed into one segment which keeps the first
    segment's attributes and takes the ``end`` of the last segment in the run.

    The stored segments are left untouched: merging is done on shallow
    copies, so calling this method does not stretch the ``end`` of the
    objects held in ``self.all_diarization_segments``.

    Returns:
        A new list of merged segments, or an empty list when there are no
        diarization segments.
    """
    from copy import copy  # local import keeps this block self-contained

    source_segments = self.all_diarization_segments
    if not source_segments:
        return []

    merged = []
    for segment in source_segments:
        if merged and merged[-1].speaker == segment.speaker:
            # Same speaker as the current run: extend the copy, never the original.
            merged[-1].end = segment.end
        else:
            merged.append(copy(segment))
    return merged
# Added to ``Line`` by this patch; the rest of that class lies outside the
# visible hunk.
def is_silent(self) -> bool:
    """Return True when this line is a silence marker (``speaker == -2``)."""
    return self.speaker == -2


class SilentLine(Line):
    """A ``Line`` that stands for a stretch of silence.

    Instances always carry ``speaker == -2`` and an empty ``text``; callers
    only supply the timing, e.g. ``SilentLine(start=..., end=...)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Bare class attributes (``speaker = -2``) are not enough: the
        # inherited initialiser (presumably dataclass-generated) stores its
        # own defaults on the instance, which would shadow them and make
        # is_silent() return False. Force the markers after it has run.
        self.speaker = -2
        self.text = ''


@dataclass
class PunctuationSegment(TimedText):
    """Span of tokens that ends on a punctuation token.

    ``token_index_start`` and ``token_index_end`` are inclusive indices into
    the token list the segment was built from; ``punctuation_token_index`` and
    ``punctuation_token`` identify the punctuation token attached to the
    segment. Timing and text are held in the inherited ``start``, ``end`` and
    ``text`` fields.
    """
    # These fields need defaults: the inherited fields appear to be defaulted
    # (``Line()`` is built with no arguments), and a dataclass refuses a
    # non-default field after a defaulted one with a TypeError at class
    # creation. Use from_token_range() to build fully populated instances.
    token_index_start: int = 0
    token_index_end: int = 0
    punctuation_token_index: int = 0
    punctuation_token: ASRToken = None  # None only as a placeholder default

    @classmethod
    def from_token_range(
        cls,
        tokens: List[ASRToken],
        token_index_start: int,
        token_index_end: int,
        punctuation_token_index: int
    ) -> "PunctuationSegment":
        """Create a segment from ``tokens[token_index_start:token_index_end + 1]``.

        Args:
            tokens: Token list the indices refer to.
            token_index_start: Index of the first token of the segment.
            token_index_end: Index of the last token of the segment (inclusive).
            punctuation_token_index: Index of the punctuation token to attach.

        Returns:
            A ``PunctuationSegment`` whose ``start`` is the first token's
            ``start``, whose ``end`` is the last token's ``end`` and whose
            ``text`` is the concatenation of the covered tokens' ``text``.

        Raises:
            ValueError: If ``tokens`` is empty, the range is reversed, or any
                index falls outside ``tokens``.
        """
        token_count = len(tokens) if tokens else 0
        if not (0 <= token_index_start <= token_index_end < token_count):
            raise ValueError("Invalid token indices")
        # Checked explicitly so a negative index cannot silently wrap around
        # and pick a token from the end of the list.
        if not (0 <= punctuation_token_index < token_count):
            raise ValueError("Invalid punctuation token index")

        segment_tokens = tokens[token_index_start:token_index_end + 1]
        return cls(
            start=segment_tokens[0].start,
            end=segment_tokens[-1].end,
            text=''.join(token.text for token in segment_tokens),
            token_index_start=token_index_start,
            token_index_end=token_index_end,
            punctuation_token_index=punctuation_token_index,
            punctuation_token=tokens[punctuation_token_index],
        )