diarization corrections

This commit is contained in:
Quentin Fuxa
2025-11-19 19:06:03 +01:00
parent 3104f40f6e
commit 11e9def0b2
3 changed files with 48 additions and 52 deletions

View File

@@ -380,10 +380,10 @@ class AudioProcessor:
item = await get_all_from_queue(self.diarization_queue)
if item is SENTINEL:
logger.debug("Diarization processor received sentinel. Finishing.")
self.diarization_queue.task_done()
break
elif type(item) is Silence and item.has_ended:
diarization_obj.insert_silence(item.duration)
elif type(item) is Silence:
if item.has_ended:
diarization_obj.insert_silence(item.duration)
continue
elif isinstance(item, np.ndarray):
pcm_array = item
@@ -425,13 +425,10 @@ class AudioProcessor:
)
if len(self.state.tokens) > 0:
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
self.diarization_queue.task_done()
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL:
self.diarization_queue.task_done()
logger.info("Diarization processor task finished.")
async def translation_processor(self):

View File

@@ -9,22 +9,16 @@ logger.setLevel(logging.DEBUG)
CHECK_AROUND = 4
DEBUG = False
def is_punctuation(token):
if token.is_punctuation():
return True
return False
def next_punctuation_change(i, tokens):
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
if is_punctuation(tokens[ind]):
if tokens[ind].is_punctuation():
return ind
return None
def next_speaker_change(i, tokens, speaker):
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
token = tokens[ind]
if is_punctuation(token):
if token.is_punctuation():
break
if token.speaker != speaker:
return ind, token.speaker
@@ -58,8 +52,8 @@ def format_output(state, silence, args, sep):
tokens = state.tokens
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
last_validated_token = state.last_validated_token
previous_speaker = 1
last_speaker = abs(state.last_speaker)
undiarized_text = []
tokens = handle_silences(tokens, state.beg_loop, silence)
for i in range(last_validated_token, len(tokens)):
@@ -71,50 +65,54 @@ def format_output(state, silence, args, sep):
token.corrected_speaker = 1
token.validated_speaker = True
else:
if is_punctuation(token):
state.last_punctuation_index = i
if state.last_punctuation_index == i-1:
if token.speaker != previous_speaker:
if token.speaker == -1:
undiarized_text.append(token.text)
elif token.is_punctuation():
state.last_punctuation_index = i
token.corrected_speaker = last_speaker
token.validated_speaker = True
elif state.last_punctuation_index == i-1:
if token.speaker != last_speaker:
token.corrected_speaker = token.speaker
token.validated_speaker = True
# perfect, diarization perfectly aligned
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
# perfect, diarization perfectly aligned
last_punctuation = None
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
elif speaker != previous_speaker:
if not (speaker == -2 or previous_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
token.corrected_speaker = previous_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = previous_speaker
token.validated_speaker = False
elif speaker != last_speaker:
if not (speaker == -2 or last_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
token.corrected_speaker = last_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = last_speaker
token.validated_speaker = False
if token.validated_speaker:
state.last_validated_token = i
previous_speaker = token.corrected_speaker
state.last_speaker = token.corrected_speaker
previous_speaker = 1
last_speaker = 1
lines = []
for token in tokens:
if int(token.corrected_speaker) != int(previous_speaker):
lines.append(new_line(token))
else:
append_token_to_last_line(lines, sep, token)
if token.corrected_speaker != -1:
if int(token.corrected_speaker) != int(last_speaker):
lines.append(new_line(token))
else:
append_token_to_last_line(lines, sep, token)
previous_speaker = token.corrected_speaker
last_speaker = token.corrected_speaker
if lines:
unassigned_translated_segments = []

View File

@@ -180,6 +180,7 @@ class ChangeSpeaker:
class State():
tokens: list = field(default_factory=list)
last_validated_token: int = 0
last_speaker: int = 1
last_punctuation_index: Optional[int] = None
translation_validated_segments: list = field(default_factory=list)
buffer_translation: str = field(default_factory=Transcript)