From c83fd179a8f0959b91731a3f17f076c33dc3a0ea Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sat, 24 Aug 2024 19:15:00 +0200 Subject: [PATCH] improves phase shift correction between transcription and diarization --- whisperlivekit/results_formater.py | 113 ++++++++++++++++++++++------- 1 file changed, 85 insertions(+), 28 deletions(-) diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py index 6b9ba44..5d1931f 100644 --- a/whisperlivekit/results_formater.py +++ b/whisperlivekit/results_formater.py @@ -7,20 +7,55 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) PUNCTUATION_MARKS = {'.', '!', '?'} +CHECK_AROUND = 4 def format_time(seconds: float) -> str: """Format seconds as HH:MM:SS.""" return str(timedelta(seconds=int(seconds))) -def check_punctuation_nearby(i, tokens): - if i < len(tokens): - for ind in range(i, min(len(tokens), i+1)): #we check in the next 1 tokens - if tokens[ind].text.strip() in PUNCTUATION_MARKS: - return True +def is_punctuation(token): + if token.text.strip() in PUNCTUATION_MARKS: + return True return False + +def next_punctuation_change(i, tokens): + for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)): + if is_punctuation(tokens[ind]): + return ind + return None + +def next_speaker_change(i, tokens, speaker): + for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1): + token = tokens[ind] + if is_punctuation(token): + break + if token.speaker != speaker: + return ind, token.speaker + return None, speaker - + +def new_line( + token, + speaker, + last_end_diarized, + debug_info = "" +): + return { + "speaker": int(speaker), + "text": token.text + debug_info, + "beg": format_time(token.start), + "end": format_time(token.end), + "diff": round(token.end - last_end_diarized, 2) + } + + +def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized): + if token.text: + lines[-1]["text"] += sep + token.text + debug_info + lines[-1]["end"] = format_time(token.end) + lines[-1]["diff"] = round(token.end - last_end_diarized, 2) + def format_output(state, silence, current_time, diarization, debug): tokens = state["tokens"] @@ -34,13 +69,12 @@ def format_output(state, silence, current_time, diarization, debug): last_end_diarized = 0 undiarized_text = [] tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence) + last_punctuation = None for i, token in enumerate(tokens): speaker = token.speaker - if len(tokens) == 1 and not diarization: - if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1' - speaker = 1 - + if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1' + speaker = 1 if diarization and not tokens[-1].speaker == -2: if (speaker in [-1, 0]) and token.end >= end_attributed_speaker: undiarized_text.append(token.text) @@ -53,23 +87,46 @@ def format_output(state, silence, current_time, diarization, debug): debug_info = "" if debug: debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]" - if speaker != previous_speaker or not lines: - if speaker != previous_speaker and lines and check_punctuation_nearby(i, tokens): # check if punctuation nearby - lines[-1]["text"] += sep + token.text + debug_info - lines[-1]["end"] = format_time(token.end) - lines[-1]["diff"] = round(token.end - last_end_diarized, 2) + + if not lines: + lines.append(new_line(token, speaker, last_end_diarized, debug_info = "")) + continue + else: + previous_speaker = lines[-1]['speaker'] + + if is_punctuation(token): + last_punctuation = i + + + if last_punctuation == i-1: + if speaker != previous_speaker: + # perfect, diarization perfectly aligned + lines.append(new_line(token, speaker, last_end_diarized, debug_info = "")) + last_punctuation, next_punctuation = None, None + continue + + speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker) + if speaker_change_pos: + # Corrects delay: + # That was the idea. Okay haha |SPLIT SPEAKER| that's a good one + # should become: + # That was the idea. |SPLIT SPEAKER| Okay haha that's a good one + lines.append(new_line(token, new_speaker, last_end_diarized, debug_info = "")) else: - lines.append({ - "speaker": int(speaker), - "text": token.text + debug_info, - "beg": format_time(token.start), - "end": format_time(token.end), - "diff": round(token.end - last_end_diarized, 2) - }) - previous_speaker = speaker - elif token.text: # Only append if text isn't empty - lines[-1]["text"] += sep + token.text + debug_info - lines[-1]["end"] = format_time(token.end) - lines[-1]["diff"] = round(token.end - last_end_diarized, 2) + # No speaker change to come + append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized) + continue + + + if speaker != previous_speaker: + if next_punctuation_change(i, tokens): + # Corrects advance: + # Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely + # should become: + # Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely + append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized) + continue + + append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized) + return lines, undiarized_text, buffer_transcription, '' - return lines, undiarized_text, buffer_transcription, '' \ No newline at end of file