This commit is contained in:
Quentin Fuxa
2024-11-01 23:52:00 +01:00
parent 41ca17acda
commit 33cd160d58
4 changed files with 192 additions and 74 deletions

View File

@@ -179,4 +179,5 @@ def online_translation_factory(args, translation_model):
# one shared NLLB model for all speakers
#one tokenizer per speaker/language
from nllw import OnlineTranslation
from nllw import OnlineTranslation
return OnlineTranslation(translation_model, [args.lan], [args.target_language])

View File

@@ -81,7 +81,8 @@ def no_token_to_silence(tokens):
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
current_time = time() - (beg_loop if beg_loop else 0.0)
last_token = tokens[-1]
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
silence_duration = current_time - last_token.end
if (vac_detected_silence and silence_duration > END_SILENCE_DURATION_VAC) or (silence_duration >= END_SILENCE_DURATION):
if last_token.speaker == -2:
last_token.end = current_time
else:

View File

@@ -1,7 +1,7 @@
import logging
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time
from whisperlivekit.timed_objects import Line, Segment, format_time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -72,83 +72,96 @@ def format_output(state, silence, args, sep):
token.corrected_speaker = 1
token.validated_speaker = True
else:
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if token.speaker != previous_speaker:
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if token.speaker != previous_speaker:
token.validated_speaker = True
# perfect, diarization perfectly aligned
last_punctuation = None
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
# perfect, diarization perfectly aligned
last_punctuation = None
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
elif speaker != previous_speaker:
if not (speaker == -2 or previous_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
elif speaker != previous_speaker:
if not (speaker == -2 or previous_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
token.corrected_speaker = previous_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = previous_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = previous_speaker
token.validated_speaker = False
token.validated_speaker = False
if token.validated_speaker:
state.last_validated_token = i
state.last_validated_token = i + last_validated_token
previous_speaker = token.corrected_speaker
previous_speaker = 1
for token in tokens[last_validated_token+1:state.last_validated_token+1]:
if not state.segments or int(token.corrected_speaker) != int(state.segments[-1].speaker):
state.segments.append(
Segment(
speaker=token.corrected_speaker,
words=[token]
)
)
else:
state.segments[-1].words.append(token)
for token in tokens[state.last_validated_token+1:]:
# if not state.segments or int(token.corrected_speaker) != int(state.segments[-1].speaker):
# state.segments.append(
# Segment(
# speaker=token.corrected_speaker,
# buffer_tokens=[token]
# )
# )
# else:
state.segments[-1].buffer_tokens.append(token)
for segment in state.segments:
segment.consolidate(sep)
# lines = []
# for token in tokens:
# if int(token.corrected_speaker) != int(previous_speaker):
# lines.append(new_line(token))
# else:
# append_token_to_last_line(lines, sep, token)
# previous_speaker = token.corrected_speaker
for ts in translation_validated_segments:
for segment in state.segments[state.last_validated_segment:]:
if ts.is_within(segment):
segment.translation += ts.text + sep
break
for ts in translation_buffer:
for segment in state.segments[state.last_validated_segment:]:
if ts.is_within(segment):
segment.buffer.translation += ts.text + sep
break
# if state.buffer_transcription and lines:
# lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
lines = []
for token in tokens:
if int(token.corrected_speaker) != int(previous_speaker):
lines.append(new_line(token))
else:
append_token_to_last_line(lines, sep, token)
previous_speaker = token.corrected_speaker
if lines:
unassigned_translated_segments = []
for ts in translation_validated_segments:
assigned = False
for line in lines:
if ts and ts.overlaps_with(line):
if ts.is_within(line):
line.translation += ts.text + ' '
assigned = True
break
else:
ts0, ts1 = ts.approximate_cut_at(line.end)
if ts0 and line.overlaps_with(ts0):
line.translation += ts0.text + ' '
if ts1:
unassigned_translated_segments.append(ts1)
assigned = True
break
if not assigned:
unassigned_translated_segments.append(ts)
if unassigned_translated_segments:
for line in lines:
remaining_segments = []
for ts in unassigned_translated_segments:
if ts and ts.overlaps_with(line):
line.translation += ts.text + ' '
else:
remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
for segment in state.segments:
lines.append(Line(
start=segment.start,
end=segment.end,
speaker=segment.speaker,
text=segment.text,
translation=segment.translation
))
if state.buffer_transcription and lines:
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
return lines, undiarized_text

View File

@@ -91,8 +91,12 @@ class SpeakerSegment(TimedText):
@dataclass
class Translation(TimedText):
is_validated : bool = False
pass
# def split(self):
# return self.text.split(" ") # should be customized with the sep
def approximate_cut_at(self, cut_time):
"""
Each word in text is considered to be of duration (end-start)/len(words in text)
@@ -120,6 +124,19 @@ class Translation(TimedText):
return segment0, segment1
def cut_position(self, position):
sep=" "
words = self.text.split(sep)
num_words = len(words)
duration_per_word = self.duration() / num_words
cut_time=duration_per_word*position
text0 = sep.join(words[:position])
text1 = sep.join(words[position:])
segment0 = Translation(start=self.start, end=cut_time, text=text0)
segment1 = Translation(start=cut_time, end=self.end, text=text1)
return segment0, segment1
@dataclass
class Silence():
@@ -143,6 +160,90 @@ class Line(TimedText):
_dict['detected_language'] = self.detected_language
return _dict
@dataclass
class WordValidation:
    """Tracks which aspects of a word (text, speaker, language) are confirmed."""
    text: bool = False
    speaker: bool = False
    language: bool = False

    def to_dict(self):
        """Return the three validation flags as a JSON-ready dict."""
        flag_names = ('text', 'speaker', 'language')
        return {name: getattr(self, name) for name in flag_names}
@dataclass
class Word:
    """A single recognized word: its text, its timing, and its validation flags."""
    text: str = ''
    start: float = 0.0
    end: float = 0.0
    validated: WordValidation = field(default_factory=WordValidation)

    def to_dict(self):
        """Serialize to a JSON-friendly dict, expanding the nested validation flags."""
        payload = {'text': self.text, 'start': self.start, 'end': self.end}
        payload['validated'] = self.validated.to_dict()
        return payload
@dataclass
class SegmentBuffer:
    """Holds a segment's not-yet-finalized text for each processing stage."""
    transcription: str = ''
    diarization: str = ''
    translation: str = ''

    def to_dict(self):
        """Return the buffered strings keyed by stage name."""
        stages = ('transcription', 'diarization', 'translation')
        return {stage: getattr(self, stage) for stage in stages}
@dataclass
class Segment:
    """A contiguous run of words attributed to one speaker (new API structure)."""
    id: int = 0
    speaker: int = -1
    text: str = ''
    start_speaker: float = 0.0
    start: float = 0.0
    end: float = 0.0
    language: Optional[str] = None
    translation: str = ''
    words: List[ASRToken] = field(default_factory=list)
    buffer_tokens: List[ASRToken] = field(default_factory=list)
    # NOTE(review): no annotation, so this is a shared class attribute rather
    # than a dataclass field — presumably legacy; `buffer.translation` appears
    # to supersede it. TODO confirm before removing.
    buffer_translation = ''
    buffer: SegmentBuffer = field(default_factory=SegmentBuffer)

    def to_dict(self):
        """Convert the segment (including its words and buffer) to a JSON-ready dict."""
        data = {
            'id': self.id,
            'speaker': self.speaker,
            'text': self.text,
            'start_speaker': self.start_speaker,
            'start': self.start,
            'end': self.end,
            'language': self.language,
            'translation': self.translation,
        }
        data['words'] = [w.to_dict() for w in self.words]
        data['buffer'] = self.buffer.to_dict()
        return data

    def consolidate(self, sep):
        """Rebuild `text` by joining the word texts with *sep* and refresh start/end
        timestamps from the first and last word."""
        self.text = sep.join(word.text for word in self.words)
        if self.words:
            first, last = self.words[0], self.words[-1]
            self.start = first.start
            self.end = last.end
@dataclass
class FrontData():
@@ -175,7 +276,9 @@ class ChangeSpeaker:
@dataclass
class State():
tokens: list = field(default_factory=list)
segments: list = field(default_factory=list)
last_validated_token: int = 0
last_validated_segment: int = 0 # validated means tokens speaker and transcription are validated and terminated
translation_validated_segments: list = field(default_factory=list)
translation_buffer: list = field(default_factory=list)
buffer_transcription: str = field(default_factory=Transcript)
@@ -183,4 +286,4 @@ class State():
end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0
beg_loop: Optional[int] = None
beg_loop: Optional[int] = None