diff --git a/whisperlivekit/TokensAlignment.py b/whisperlivekit/TokensAlignment.py index 46b2c59..cf1ec47 100644 --- a/whisperlivekit/TokensAlignment.py +++ b/whisperlivekit/TokensAlignment.py @@ -1,17 +1,76 @@ +from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment, Silence +from time import time + + class TokensAlignment: - def __init__(self, state_light, silence=None, args=None): - self.state_light = state_light - self.silence = silence - self.args = args - + def __init__(self, state, args, sep): + self.state = state + self.diarization = args.diarization self._tokens_index = 0 self._diarization_index = 0 self._translation_index = 0 - def update(self): - pass + self.all_tokens = [] + self.all_diarization_segments = [] + self.all_translation_segments = [] + self.new_tokens = [] + self.new_translation = [] + self.new_diarization = [] + self.new_tokens_buffer = [] + self.sep = sep + + def update(self): + self.new_tokens, self.state.new_tokens = self.state.new_tokens, [] + self.new_diarization, self.state.new_diarization = self.state.new_diarization, [] + self.new_translation, self.state.new_translation = self.state.new_translation, [] + self.new_tokens_buffer, self.state.new_tokens_buffer = self.state.new_tokens_buffer, [] + + self.all_tokens.extend(self.new_tokens) + self.all_diarization_segments.extend(self.new_diarization) + self.all_translation_segments.extend(self.new_translation) + + def create_lines_from_tokens(self, current_silence, beg_loop): + lines = [] + current_line_tokens = [] + for token in self.all_tokens: + if isinstance(token, Silence): + if current_line_tokens: + lines.append(Line().build_from_tokens(current_line_tokens)) + current_line_tokens = [] + end_silence = token.end if token.has_ended else time() - beg_loop + if lines and lines[-1].speaker == -2: + lines[-1].end = end_silence + else: + lines.append(Line( + speaker = -2, + text = '', + start = token.start, + end = end_silence + )) + else: + current_line_tokens.append(token) 
+ if current_line_tokens: + lines.append(Line().build_from_tokens(current_line_tokens)) + if current_silence: + end_silence = current_silence.end if current_silence.has_ended else time() - beg_loop + if lines and lines[-1].speaker == -2: + lines[-1].end = end_silence + else: + lines.append(Line( + speaker = -2, + text = '', + start = current_silence.start, + end = end_silence + )) + + return lines + + def align_tokens(self): + if not self.diarization: + pass + # return self.all_tokens def compute_punctuations_segments(self): punctuations_breaks = [] @@ -39,130 +98,3 @@ class TokensAlignment: def concatenate_diar_segments(self): diarization_segments = self.state.diarization_segments - -if __name__ == "__main__": - from whisperlivekit.timed_objects import State, ASRToken, SpeakerSegment, Transcript, Silence - - # Reconstruct the state from the backup data - tokens = [ - ASRToken(start=1.38, end=1.48, text=' The'), - ASRToken(start=1.42, end=1.52, text=' description'), - ASRToken(start=1.82, end=1.92, text=' technology'), - ASRToken(start=2.54, end=2.64, text=' has'), - ASRToken(start=2.7, end=2.8, text=' improved'), - ASRToken(start=3.24, end=3.34, text=' so'), - ASRToken(start=3.66, end=3.76, text=' much'), - ASRToken(start=4.02, end=4.12, text=' in'), - ASRToken(start=4.08, end=4.18, text=' the'), - ASRToken(start=4.26, end=4.36, text=' past'), - ASRToken(start=4.48, end=4.58, text=' few'), - ASRToken(start=4.76, end=4.86, text=' years'), - ASRToken(start=5.76, end=5.86, text='.'), - ASRToken(start=5.72, end=5.82, text=' Have'), - ASRToken(start=5.92, end=6.02, text=' you'), - ASRToken(start=6.08, end=6.18, text=' noticed'), - ASRToken(start=6.52, end=6.62, text=' how'), - ASRToken(start=6.8, end=6.9, text=' accurate'), - ASRToken(start=7.46, end=7.56, text=' real'), - ASRToken(start=7.72, end=7.82, text='-time'), - ASRToken(start=8.06, end=8.16, text=' speech'), - ASRToken(start=8.48, end=8.58, text=' to'), - ASRToken(start=8.68, end=8.78, text=' text'), - 
ASRToken(start=9.0, end=9.1, text=' is'), - ASRToken(start=9.24, end=9.34, text=' now'), - ASRToken(start=9.82, end=9.92, text='?'), - ASRToken(start=9.86, end=9.96, text=' Absolutely'), - ASRToken(start=11.26, end=11.36, text='.'), - ASRToken(start=11.36, end=11.46, text=' I'), - ASRToken(start=11.58, end=11.68, text=' use'), - ASRToken(start=11.78, end=11.88, text=' it'), - ASRToken(start=11.94, end=12.04, text=' all'), - ASRToken(start=12.08, end=12.18, text=' the'), - ASRToken(start=12.32, end=12.42, text=' time'), - ASRToken(start=12.58, end=12.68, text=' for'), - ASRToken(start=12.78, end=12.88, text=' taking'), - ASRToken(start=13.14, end=13.24, text=' notes'), - ASRToken(start=13.4, end=13.5, text=' during'), - ASRToken(start=13.78, end=13.88, text=' meetings'), - ASRToken(start=14.6, end=14.7, text='.'), - ASRToken(start=14.82, end=14.92, text=' It'), - ASRToken(start=14.92, end=15.02, text="'s"), - ASRToken(start=15.04, end=15.14, text=' amazing'), - ASRToken(start=15.5, end=15.6, text=' how'), - ASRToken(start=15.66, end=15.76, text=' it'), - ASRToken(start=15.8, end=15.9, text=' can'), - ASRToken(start=15.96, end=16.06, text=' recognize'), - ASRToken(start=16.58, end=16.68, text=' different'), - ASRToken(start=16.94, end=17.04, text=' speakers'), - ASRToken(start=17.82, end=17.92, text=' and'), - ASRToken(start=18.0, end=18.1, text=' even'), - ASRToken(start=18.42, end=18.52, text=' add'), - ASRToken(start=18.74, end=18.84, text=' punct'), - ASRToken(start=19.02, end=19.12, text='uation'), - ASRToken(start=19.68, end=19.78, text='.'), - ASRToken(start=20.04, end=20.14, text=' Yeah'), - ASRToken(start=20.5, end=20.6, text=','), - ASRToken(start=20.6, end=20.7, text=' but'), - ASRToken(start=20.76, end=20.86, text=' sometimes'), - ASRToken(start=21.42, end=21.52, text=' noise'), - ASRToken(start=21.82, end=21.92, text=' can'), - ASRToken(start=22.08, end=22.18, text=' still'), - ASRToken(start=22.38, end=22.48, text=' cause'), - ASRToken(start=22.72, 
end=22.82, text=' mistakes'), - ASRToken(start=23.74, end=23.84, text='.'), - ASRToken(start=23.96, end=24.06, text=' Does'), - ASRToken(start=24.16, end=24.26, text=' this'), - ASRToken(start=24.4, end=24.5, text=' system'), - ASRToken(start=24.76, end=24.86, text=' handle'), - ASRToken(start=25.12, end=25.22, text=' that'), - ASRToken(start=25.38, end=25.48, text=' well'), - ASRToken(start=25.68, end=25.78, text='?'), - ASRToken(start=26.4, end=26.5, text=' It'), - ASRToken(start=26.5, end=26.6, text=' does'), - ASRToken(start=26.7, end=26.8, text=' a'), - ASRToken(start=27.08, end=27.18, text=' pretty'), - ASRToken(start=27.12, end=27.22, text=' good'), - ASRToken(start=27.34, end=27.44, text=' job'), - ASRToken(start=27.64, end=27.74, text=' filtering'), - ASRToken(start=28.1, end=28.2, text=' noise'), - ASRToken(start=28.64, end=28.74, text=','), - ASRToken(start=28.78, end=28.88, text=' especially'), - ASRToken(start=29.3, end=29.4, text=' with'), - ASRToken(start=29.51, end=29.61, text=' models'), - ASRToken(start=29.99, end=30.09, text=' that'), - ASRToken(start=30.21, end=30.31, text=' use'), - ASRToken(start=30.51, end=30.61, text=' voice'), - ASRToken(start=30.83, end=30.93, text=' activity'), - ] - - diarization_segments = [ - SpeakerSegment(start=1.3255040645599365, end=4.3255040645599365, speaker=0), - SpeakerSegment(start=4.806154012680054, end=9.806154012680054, speaker=0), - SpeakerSegment(start=9.806154012680054, end=10.806154012680054, speaker=1), - SpeakerSegment(start=11.168735027313232, end=14.168735027313232, speaker=1), - SpeakerSegment(start=14.41029405593872, end=17.41029405593872, speaker=1), - SpeakerSegment(start=17.52983808517456, end=19.52983808517456, speaker=1), - SpeakerSegment(start=19.64953374862671, end=20.066200415293377, speaker=1), - SpeakerSegment(start=20.066200415293377, end=22.64953374862671, speaker=2), - SpeakerSegment(start=23.012792587280273, end=25.012792587280273, speaker=2), - 
SpeakerSegment(start=25.495875597000122, end=26.41254226366679, speaker=2), - SpeakerSegment(start=26.41254226366679, end=30.495875597000122, speaker=0), - ] - - state = State( - tokens=tokens, - last_validated_token=72, - last_speaker=-1, - last_punctuation_index=71, - translation_validated_segments=[], - buffer_translation=Transcript(start=0, end=0, speaker=-1), - buffer_transcription=Transcript(start=None, end=None, speaker=-1), - diarization_segments=diarization_segments, - end_buffer=31.21587559700018, - end_attributed_speaker=30.495875597000122, - remaining_time_transcription=0.4, - remaining_time_diarization=0.7, - beg_loop=1763627603.968919 - ) - - alignment = TokensAlignment(state) \ No newline at end of file diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 353924d..1f290c5 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -81,10 +81,7 @@ class AudioProcessor: # State management self.is_stopping = False - self.silence = True - self.silence_duration = 0.0 - self.start_silence = None - self.last_silence_dispatch_time = None + self.current_silence = None self.state = State() self.state_light = StateLight() self.lock = asyncio.Lock() @@ -142,33 +139,34 @@ class AudioProcessor: if models.translation_model: self.translation = online_translation_factory(self.args, models.translation_model) - async def _push_silence_event(self, silence_buffer: Silence): + async def _push_silence_event(self): if self.transcription_queue: - await self.transcription_queue.put(silence_buffer) + await self.transcription_queue.put(self.current_silence) if self.args.diarization and self.diarization_queue: - await self.diarization_queue.put(silence_buffer) + await self.diarization_queue.put(self.current_silence) if self.translation_queue: - await self.translation_queue.put(silence_buffer) + await self.translation_queue.put(self.current_silence) async def _begin_silence(self): - if self.silence: + if 
self.current_silence: return - self.silence = True - now = time() - self.start_silence = now - self.last_silence_dispatch_time = now - await self._push_silence_event(Silence(is_starting=True)) + now = time() - self.beg_loop + self.current_silence = Silence( + is_starting=True, start=now + ) + await self._push_silence_event() async def _end_silence(self): - if not self.silence: + if not self.current_silence: return - now = time() - duration = now - (self.last_silence_dispatch_time if self.last_silence_dispatch_time else self.beg_loop) - await self._push_silence_event(Silence(duration=duration, has_ended=True)) - self.last_silence_dispatch_time = now - self.silence = False - self.start_silence = None - self.last_silence_dispatch_time = None + now = time() - self.beg_loop + self.current_silence.end = now + self.current_silence.is_starting=False + self.current_silence.has_ended=True + self.current_silence.compute_duration() + self.state_light.new_tokens.append(self.current_silence) + await self._push_silence_event() + self.current_silence = None async def _enqueue_active_audio(self, pcm_chunk: np.ndarray): if pcm_chunk is None or pcm_chunk.size == 0: @@ -177,7 +175,6 @@ class AudioProcessor: await self.transcription_queue.put(pcm_chunk.copy()) if self.args.diarization and self.diarization_queue: await self.diarization_queue.put(pcm_chunk.copy()) - self.silence_duration = 0.0 def _slice_before_silence(self, pcm_array, chunk_sample_start, silence_sample): if silence_sample is None: @@ -332,8 +329,7 @@ class AudioProcessor: self.state.tokens.extend(new_tokens) self.state.buffer_transcription = _buffer_transcript self.end_buffer = max(candidate_end_times) - self.state_light.new_tokens = new_tokens - self.state_light.new_tokens += 1 + self.state_light.new_tokens.extend(new_tokens) self.state_light.new_tokens_buffer = _buffer_transcript if self.translation_queue: @@ -412,14 +408,17 @@ class AudioProcessor: await asyncio.sleep(1) continue + self.tokens_alignment.update() + 
lines = self.tokens_alignment.create_lines_from_tokens(self.current_silence, self.beg_loop) + undiarized_text = '' state = await self.get_current_state() - self.tokens_alignment.compute_punctuations_segments() - lines, undiarized_text = format_output( - state, - self.silence, - args = self.args, - sep=self.sep - ) + # self.tokens_alignment.compute_punctuations_segments() + # lines, undiarized_text = format_output( + # state, + # self.current_silence, + # args = self.args, + # sep=self.sep + # ) if lines and lines[-1].speaker == -2: buffer_transcription = Transcript() else: @@ -581,6 +580,7 @@ class AudioProcessor: if not self.beg_loop: self.beg_loop = time() + self.current_silence = Silence(start=0.0, is_starting=True) if not message: logger.info("Empty audio message received, initiating stop sequence.") @@ -642,17 +642,17 @@ class AudioProcessor: if res is not None: silence_detected = res.get("end", 0) > res.get("start", 0) - if silence_detected and not self.silence: + if silence_detected and not self.current_silence: pre_silence_chunk = self._slice_before_silence( pcm_array, chunk_sample_start, res.get("end") ) if pre_silence_chunk is not None and pre_silence_chunk.size > 0: await self._enqueue_active_audio(pre_silence_chunk) await self._begin_silence() - elif self.silence: + elif self.current_silence: await self._end_silence() - if not self.silence: + if not self.current_silence: await self._enqueue_active_audio(pcm_array) self.total_pcm_samples = chunk_sample_end diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 69f1542..9d50796 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -123,10 +123,16 @@ class Translation(TimedText): @dataclass class Silence(): + start: Optional[float] = None + end: Optional[float] = None duration: Optional[float] = None is_starting: bool = False has_ended: bool = False - + + def compute_duration(self) -> None: + if self.start is None or self.end is None: + return 
None + self.duration = self.end - self.start @@ -145,6 +151,14 @@ class Line(TimedText): _dict['detected_language'] = self.detected_language return _dict + def build_from_tokens(self, tokens: List[ASRToken]): + self.text = ''.join([token.text for token in tokens]) + self.start = tokens[0].start + self.end = tokens[-1].end + self.speaker = 1  # NOTE(review): hardcoded placeholder — confirm intended speaker id before diarized use + return self + + @dataclass class FrontData(): @@ -197,7 +211,4 @@ class StateLight(): new_tokens: list = field(default_factory=list) new_translation: list = field(default_factory=list) new_diarization: list = field(default_factory=list) - new_tokens_buffer: list = field(default_factory=list) #only when local agreement - new_tokens_index = 0 - new_translation_index = 0 - new_diarization_index = 0 + new_tokens_buffer: list = field(default_factory=list) #only when local agreement \ No newline at end of file