diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 46db751..3150c7e 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -380,10 +380,10 @@ class AudioProcessor: item = await get_all_from_queue(self.diarization_queue) if item is SENTINEL: logger.debug("Diarization processor received sentinel. Finishing.") - self.diarization_queue.task_done() break - elif type(item) is Silence and item.has_ended: - diarization_obj.insert_silence(item.duration) + elif type(item) is Silence: + if item.has_ended: + diarization_obj.insert_silence(item.duration) continue elif isinstance(item, np.ndarray): pcm_array = item @@ -425,13 +425,10 @@ class AudioProcessor: ) if len(self.state.tokens) > 0: self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker) - self.diarization_queue.task_done() except Exception as e: logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Traceback: {traceback.format_exc()}") - if 'pcm_array' in locals() and pcm_array is not SENTINEL: - self.diarization_queue.task_done() logger.info("Diarization processor task finished.") async def translation_processor(self): diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py index f5a64bf..86622fb 100644 --- a/whisperlivekit/results_formater.py +++ b/whisperlivekit/results_formater.py @@ -9,22 +9,16 @@ logger.setLevel(logging.DEBUG) CHECK_AROUND = 4 DEBUG = False - -def is_punctuation(token): - if token.is_punctuation(): - return True - return False - def next_punctuation_change(i, tokens): for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)): - if is_punctuation(tokens[ind]): + if tokens[ind].is_punctuation(): return ind return None def next_speaker_change(i, tokens, speaker): for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1): token = tokens[ind] - if is_punctuation(token): + if token.is_punctuation(): break if token.speaker != speaker: return ind, token.speaker @@ -58,8 +52,8 @@ def format_output(state, silence, args, sep): tokens = state.tokens translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments last_validated_token = state.last_validated_token - - previous_speaker = 1 + + last_speaker = abs(state.last_speaker) undiarized_text = [] tokens = handle_silences(tokens, state.beg_loop, silence) for i in range(last_validated_token, len(tokens)): @@ -71,50 +65,54 @@ def format_output(state, silence, args, sep): token.corrected_speaker = 1 token.validated_speaker = True else: - if is_punctuation(token): - state.last_punctuation_index = i - - if state.last_punctuation_index == i-1: - if token.speaker != previous_speaker: + if token.speaker == -1: + undiarized_text.append(token.text) + elif token.is_punctuation(): + state.last_punctuation_index = i + token.corrected_speaker = last_speaker + token.validated_speaker = True + elif state.last_punctuation_index == i-1: + if token.speaker != last_speaker: + token.corrected_speaker = token.speaker + token.validated_speaker = True + # perfect, diarization perfectly aligned + else: + speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker) + if speaker_change_pos: + # Corrects delay: + # That was the idea. haha |SPLIT SPEAKER| that's a good one + # should become: + # That was the idea. |SPLIT SPEAKER| haha that's a good one + token.corrected_speaker = new_speaker token.validated_speaker = True - # perfect, diarization perfectly aligned - last_punctuation = None - else: - speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker) - if speaker_change_pos: - # Corrects delay: - # That was the idea. haha |SPLIT SPEAKER| that's a good one - # should become: - # That was the idea. |SPLIT SPEAKER| haha that's a good one - token.corrected_speaker = new_speaker - token.validated_speaker = True - elif speaker != previous_speaker: - if not (speaker == -2 or previous_speaker == -2): - if next_punctuation_change(i, tokens): - # Corrects advance: - # Are you |SPLIT SPEAKER| ? yeah, sure. Absolutely - # should become: - # Are you ? |SPLIT SPEAKER| yeah, sure. Absolutely - token.corrected_speaker = previous_speaker - token.validated_speaker = True - else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True. - if not disable_punctuation_split: - token.corrected_speaker = previous_speaker - token.validated_speaker = False + elif speaker != last_speaker: + if not (speaker == -2 or last_speaker == -2): + if next_punctuation_change(i, tokens): + # Corrects advance: + # Are you |SPLIT SPEAKER| ? yeah, sure. Absolutely + # should become: + # Are you ? |SPLIT SPEAKER| yeah, sure. Absolutely + token.corrected_speaker = last_speaker + token.validated_speaker = True + else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True. + if not disable_punctuation_split: + token.corrected_speaker = last_speaker + token.validated_speaker = False if token.validated_speaker: state.last_validated_token = i - previous_speaker = token.corrected_speaker + state.last_speaker = token.corrected_speaker - previous_speaker = 1 + last_speaker = 1 lines = [] for token in tokens: - if int(token.corrected_speaker) != int(previous_speaker): - lines.append(new_line(token)) - else: - append_token_to_last_line(lines, sep, token) + if token.corrected_speaker != -1: + if int(token.corrected_speaker) != int(last_speaker): + lines.append(new_line(token)) + else: + append_token_to_last_line(lines, sep, token) - previous_speaker = token.corrected_speaker + last_speaker = token.corrected_speaker if lines: unassigned_translated_segments = [] diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 03222b0..2a73e8b 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -180,6 +180,7 @@ class ChangeSpeaker: class State(): tokens: list = field(default_factory=list) last_validated_token: int = 0 + last_speaker: int = 1 last_punctuation_index: Optional[int] = None translation_validated_segments: list = field(default_factory=list) buffer_translation: str = field(default_factory=Transcript)