mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
diarization corrections
This commit is contained in:
@@ -380,10 +380,10 @@ class AudioProcessor:
|
||||
item = await get_all_from_queue(self.diarization_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Diarization processor received sentinel. Finishing.")
|
||||
self.diarization_queue.task_done()
|
||||
break
|
||||
elif type(item) is Silence and item.has_ended:
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
elif type(item) is Silence:
|
||||
if item.has_ended:
|
||||
diarization_obj.insert_silence(item.duration)
|
||||
continue
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
@@ -425,13 +425,10 @@ class AudioProcessor:
|
||||
)
|
||||
if len(self.state.tokens) > 0:
|
||||
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
|
||||
self.diarization_queue.task_done()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'pcm_array' in locals() and pcm_array is not SENTINEL:
|
||||
self.diarization_queue.task_done()
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self):
|
||||
|
||||
@@ -9,22 +9,16 @@ logger.setLevel(logging.DEBUG)
|
||||
CHECK_AROUND = 4
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def is_punctuation(token):
|
||||
if token.is_punctuation():
|
||||
return True
|
||||
return False
|
||||
|
||||
def next_punctuation_change(i, tokens):
|
||||
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
|
||||
if is_punctuation(tokens[ind]):
|
||||
if tokens[ind].is_punctuation():
|
||||
return ind
|
||||
return None
|
||||
|
||||
def next_speaker_change(i, tokens, speaker):
|
||||
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
|
||||
token = tokens[ind]
|
||||
if is_punctuation(token):
|
||||
if token.is_punctuation():
|
||||
break
|
||||
if token.speaker != speaker:
|
||||
return ind, token.speaker
|
||||
@@ -58,8 +52,8 @@ def format_output(state, silence, args, sep):
|
||||
tokens = state.tokens
|
||||
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||
last_validated_token = state.last_validated_token
|
||||
|
||||
previous_speaker = 1
|
||||
|
||||
last_speaker = abs(state.last_speaker)
|
||||
undiarized_text = []
|
||||
tokens = handle_silences(tokens, state.beg_loop, silence)
|
||||
for i in range(last_validated_token, len(tokens)):
|
||||
@@ -71,50 +65,54 @@ def format_output(state, silence, args, sep):
|
||||
token.corrected_speaker = 1
|
||||
token.validated_speaker = True
|
||||
else:
|
||||
if is_punctuation(token):
|
||||
state.last_punctuation_index = i
|
||||
|
||||
if state.last_punctuation_index == i-1:
|
||||
if token.speaker != previous_speaker:
|
||||
if token.speaker == -1:
|
||||
undiarized_text.append(token.text)
|
||||
elif token.is_punctuation():
|
||||
state.last_punctuation_index = i
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = True
|
||||
elif state.last_punctuation_index == i-1:
|
||||
if token.speaker != last_speaker:
|
||||
token.corrected_speaker = token.speaker
|
||||
token.validated_speaker = True
|
||||
# perfect, diarization perfectly aligned
|
||||
else:
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
||||
token.corrected_speaker = new_speaker
|
||||
token.validated_speaker = True
|
||||
# perfect, diarization perfectly aligned
|
||||
last_punctuation = None
|
||||
else:
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
||||
token.corrected_speaker = new_speaker
|
||||
token.validated_speaker = True
|
||||
elif speaker != previous_speaker:
|
||||
if not (speaker == -2 or previous_speaker == -2):
|
||||
if next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
token.corrected_speaker = previous_speaker
|
||||
token.validated_speaker = True
|
||||
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
|
||||
if not disable_punctuation_split:
|
||||
token.corrected_speaker = previous_speaker
|
||||
token.validated_speaker = False
|
||||
elif speaker != last_speaker:
|
||||
if not (speaker == -2 or last_speaker == -2):
|
||||
if next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = True
|
||||
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
|
||||
if not disable_punctuation_split:
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = False
|
||||
if token.validated_speaker:
|
||||
state.last_validated_token = i
|
||||
previous_speaker = token.corrected_speaker
|
||||
state.last_speaker = token.corrected_speaker
|
||||
|
||||
previous_speaker = 1
|
||||
last_speaker = 1
|
||||
|
||||
lines = []
|
||||
for token in tokens:
|
||||
if int(token.corrected_speaker) != int(previous_speaker):
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
append_token_to_last_line(lines, sep, token)
|
||||
if token.corrected_speaker != -1:
|
||||
if int(token.corrected_speaker) != int(last_speaker):
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
append_token_to_last_line(lines, sep, token)
|
||||
|
||||
previous_speaker = token.corrected_speaker
|
||||
last_speaker = token.corrected_speaker
|
||||
|
||||
if lines:
|
||||
unassigned_translated_segments = []
|
||||
|
||||
@@ -180,6 +180,7 @@ class ChangeSpeaker:
|
||||
class State():
|
||||
tokens: list = field(default_factory=list)
|
||||
last_validated_token: int = 0
|
||||
last_speaker: int = 1
|
||||
last_punctuation_index: Optional[int] = None
|
||||
translation_validated_segments: list = field(default_factory=list)
|
||||
buffer_translation: str = field(default_factory=Transcript)
|
||||
|
||||
Reference in New Issue
Block a user