improves phase shift correction between transcription and diarization

This commit is contained in:
Quentin Fuxa
2024-08-24 19:15:00 +02:00
parent 5258305745
commit c83fd179a8

View File

@@ -7,20 +7,55 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
PUNCTUATION_MARKS = {'.', '!', '?'}
CHECK_AROUND = 4
def format_time(seconds: float) -> str:
"""Format seconds as HH:MM:SS."""
return str(timedelta(seconds=int(seconds)))
def check_punctuation_nearby(i, tokens):
if i < len(tokens):
for ind in range(i, min(len(tokens), i+1)): #we check in the next 1 tokens
if tokens[ind].text.strip() in PUNCTUATION_MARKS:
return True
def is_punctuation(token):
if token.text.strip() in PUNCTUATION_MARKS:
return True
return False
def next_punctuation_change(i, tokens):
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
if is_punctuation(tokens[ind]):
return ind
return None
def next_speaker_change(i, tokens, speaker):
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
token = tokens[ind]
if is_punctuation(token):
break
if token.speaker != speaker:
return ind, token.speaker
return None, speaker
def new_line(
token,
speaker,
last_end_diarized,
debug_info = ""
):
return {
"speaker": int(speaker),
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
}
def append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized):
if token.text:
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
def format_output(state, silence, current_time, diarization, debug):
tokens = state["tokens"]
@@ -34,13 +69,12 @@ def format_output(state, silence, current_time, diarization, debug):
last_end_diarized = 0
undiarized_text = []
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
last_punctuation = None
for i, token in enumerate(tokens):
speaker = token.speaker
if len(tokens) == 1 and not diarization:
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
if not diarization and speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
speaker = 1
if diarization and not tokens[-1].speaker == -2:
if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
undiarized_text.append(token.text)
@@ -53,23 +87,46 @@ def format_output(state, silence, current_time, diarization, debug):
debug_info = ""
if debug:
debug_info = f"[{format_time(token.start)} : {format_time(token.end)}]"
if speaker != previous_speaker or not lines:
if speaker != previous_speaker and lines and check_punctuation_nearby(i, tokens): # check if punctuation nearby
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
if not lines:
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
continue
else:
previous_speaker = lines[-1]['speaker']
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if speaker != previous_speaker:
# perfect, diarization perfectly aligned
lines.append(new_line(token, speaker, last_end_diarized, debug_info = ""))
last_punctuation, next_punctuation = None, None
continue
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. Okay haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| Okay haha that's a good one
lines.append(new_line(token, new_speaker, last_end_diarized, debug_info = ""))
else:
lines.append({
"speaker": int(speaker),
"text": token.text + debug_info,
"beg": format_time(token.start),
"end": format_time(token.end),
"diff": round(token.end - last_end_diarized, 2)
})
previous_speaker = speaker
elif token.text: # Only append if text isn't empty
lines[-1]["text"] += sep + token.text + debug_info
lines[-1]["end"] = format_time(token.end)
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
# No speaker change to come
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
continue
if speaker != previous_speaker:
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| okay? yeah, sure. Absolutely
# should become:
# Are you okay? |SPLIT SPEAKER| yeah, sure. Absolutely
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
continue
append_token_to_last_line(lines, sep, token, debug_info, last_end_diarized)
return lines, undiarized_text, buffer_transcription, ''
return lines, undiarized_text, buffer_transcription, ''