mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-08 06:44:09 +00:00
Compare commits
1 Commits
feature/vo
...
new-api
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33cd160d58 |
@@ -179,4 +179,5 @@ def online_translation_factory(args, translation_model):
|
|||||||
#one shared nllb model for all speaker
|
#one shared nllb model for all speaker
|
||||||
#one tokenizer per speaker/language
|
#one tokenizer per speaker/language
|
||||||
from nllw import OnlineTranslation
|
from nllw import OnlineTranslation
|
||||||
|
from nllw import OnlineTranslation
|
||||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||||
|
|||||||
@@ -81,7 +81,8 @@ def no_token_to_silence(tokens):
|
|||||||
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
|
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
|
||||||
current_time = time() - (beg_loop if beg_loop else 0.0)
|
current_time = time() - (beg_loop if beg_loop else 0.0)
|
||||||
last_token = tokens[-1]
|
last_token = tokens[-1]
|
||||||
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
silence_duration = current_time - last_token.end
|
||||||
|
if (vac_detected_silence and silence_duration > END_SILENCE_DURATION_VAC) or (silence_duration >= END_SILENCE_DURATION):
|
||||||
if last_token.speaker == -2:
|
if last_token.speaker == -2:
|
||||||
last_token.end = current_time
|
last_token.end = current_time
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from whisperlivekit.remove_silences import handle_silences
|
from whisperlivekit.remove_silences import handle_silences
|
||||||
from whisperlivekit.timed_objects import Line, format_time
|
from whisperlivekit.timed_objects import Line, Segment, format_time
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
@@ -72,83 +72,96 @@ def format_output(state, silence, args, sep):
|
|||||||
token.corrected_speaker = 1
|
token.corrected_speaker = 1
|
||||||
token.validated_speaker = True
|
token.validated_speaker = True
|
||||||
else:
|
else:
|
||||||
if is_punctuation(token):
|
if is_punctuation(token):
|
||||||
last_punctuation = i
|
last_punctuation = i
|
||||||
|
|
||||||
if last_punctuation == i-1:
|
if last_punctuation == i-1:
|
||||||
if token.speaker != previous_speaker:
|
if token.speaker != previous_speaker:
|
||||||
|
token.validated_speaker = True
|
||||||
|
# perfect, diarization perfectly aligned
|
||||||
|
last_punctuation = None
|
||||||
|
else:
|
||||||
|
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||||
|
if speaker_change_pos:
|
||||||
|
# Corrects delay:
|
||||||
|
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
||||||
|
# should become:
|
||||||
|
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
||||||
|
token.corrected_speaker = new_speaker
|
||||||
token.validated_speaker = True
|
token.validated_speaker = True
|
||||||
# perfect, diarization perfectly aligned
|
elif speaker != previous_speaker:
|
||||||
last_punctuation = None
|
if not (speaker == -2 or previous_speaker == -2):
|
||||||
else:
|
if next_punctuation_change(i, tokens):
|
||||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
# Corrects advance:
|
||||||
if speaker_change_pos:
|
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
||||||
# Corrects delay:
|
# should become:
|
||||||
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||||
# should become:
|
token.corrected_speaker = previous_speaker
|
||||||
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
token.validated_speaker = True
|
||||||
token.corrected_speaker = new_speaker
|
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
|
||||||
token.validated_speaker = True
|
if not disable_punctuation_split:
|
||||||
elif speaker != previous_speaker:
|
|
||||||
if not (speaker == -2 or previous_speaker == -2):
|
|
||||||
if next_punctuation_change(i, tokens):
|
|
||||||
# Corrects advance:
|
|
||||||
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
|
||||||
# should become:
|
|
||||||
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
|
||||||
token.corrected_speaker = previous_speaker
|
token.corrected_speaker = previous_speaker
|
||||||
token.validated_speaker = True
|
token.validated_speaker = False
|
||||||
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
|
|
||||||
if not disable_punctuation_split:
|
|
||||||
token.corrected_speaker = previous_speaker
|
|
||||||
token.validated_speaker = False
|
|
||||||
if token.validated_speaker:
|
if token.validated_speaker:
|
||||||
state.last_validated_token = i
|
state.last_validated_token = i + last_validated_token
|
||||||
previous_speaker = token.corrected_speaker
|
previous_speaker = token.corrected_speaker
|
||||||
|
|
||||||
previous_speaker = 1
|
for token in tokens[last_validated_token+1:state.last_validated_token+1]:
|
||||||
|
if not state.segments or int(token.corrected_speaker) != int(state.segments[-1].speaker):
|
||||||
|
state.segments.append(
|
||||||
|
Segment(
|
||||||
|
speaker=token.corrected_speaker,
|
||||||
|
words=[token]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
state.segments[-1].words.append(token)
|
||||||
|
|
||||||
|
for token in tokens[state.last_validated_token+1:]:
|
||||||
|
# if not state.segments or int(token.corrected_speaker) != int(state.segments[-1].speaker):
|
||||||
|
# state.segments.append(
|
||||||
|
# Segment(
|
||||||
|
# speaker=token.corrected_speaker,
|
||||||
|
# buffer_tokens=[token]
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
state.segments[-1].buffer_tokens.append(token)
|
||||||
|
|
||||||
|
for segment in state.segments:
|
||||||
|
segment.consolidate(sep)
|
||||||
|
# lines = []
|
||||||
|
# for token in tokens:
|
||||||
|
# if int(token.corrected_speaker) != int(previous_speaker):
|
||||||
|
# lines.append(new_line(token))
|
||||||
|
# else:
|
||||||
|
# append_token_to_last_line(lines, sep, token)
|
||||||
|
|
||||||
|
# previous_speaker = token.corrected_speaker
|
||||||
|
|
||||||
|
for ts in translation_validated_segments:
|
||||||
|
for segment in state.segments[state.last_validated_segment:]:
|
||||||
|
if ts.is_within(segment):
|
||||||
|
segment.translation += ts.text + sep
|
||||||
|
break
|
||||||
|
|
||||||
|
for ts in translation_buffer:
|
||||||
|
for segment in state.segments[state.last_validated_segment:]:
|
||||||
|
if ts.is_within(segment):
|
||||||
|
segment.buffer.translation += ts.text + sep
|
||||||
|
break
|
||||||
|
|
||||||
|
# if state.buffer_transcription and lines:
|
||||||
|
# lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
||||||
|
|
||||||
lines = []
|
lines = []
|
||||||
for token in tokens:
|
for segment in state.segments:
|
||||||
if int(token.corrected_speaker) != int(previous_speaker):
|
lines.append(Line(
|
||||||
lines.append(new_line(token))
|
start=segment.start,
|
||||||
else:
|
end=segment.end,
|
||||||
append_token_to_last_line(lines, sep, token)
|
speaker=segment.speaker,
|
||||||
|
text=segment.text,
|
||||||
previous_speaker = token.corrected_speaker
|
translation=segment.translation
|
||||||
|
))
|
||||||
if lines:
|
|
||||||
unassigned_translated_segments = []
|
|
||||||
for ts in translation_validated_segments:
|
|
||||||
assigned = False
|
|
||||||
for line in lines:
|
|
||||||
if ts and ts.overlaps_with(line):
|
|
||||||
if ts.is_within(line):
|
|
||||||
line.translation += ts.text + ' '
|
|
||||||
assigned = True
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
ts0, ts1 = ts.approximate_cut_at(line.end)
|
|
||||||
if ts0 and line.overlaps_with(ts0):
|
|
||||||
line.translation += ts0.text + ' '
|
|
||||||
if ts1:
|
|
||||||
unassigned_translated_segments.append(ts1)
|
|
||||||
assigned = True
|
|
||||||
break
|
|
||||||
if not assigned:
|
|
||||||
unassigned_translated_segments.append(ts)
|
|
||||||
|
|
||||||
if unassigned_translated_segments:
|
|
||||||
for line in lines:
|
|
||||||
remaining_segments = []
|
|
||||||
for ts in unassigned_translated_segments:
|
|
||||||
if ts and ts.overlaps_with(line):
|
|
||||||
line.translation += ts.text + ' '
|
|
||||||
else:
|
|
||||||
remaining_segments.append(ts)
|
|
||||||
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
|
|
||||||
|
|
||||||
if state.buffer_transcription and lines:
|
|
||||||
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
|
||||||
|
|
||||||
return lines, undiarized_text
|
return lines, undiarized_text
|
||||||
|
|||||||
@@ -91,8 +91,12 @@ class SpeakerSegment(TimedText):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Translation(TimedText):
|
class Translation(TimedText):
|
||||||
|
is_validated : bool = False
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# def split(self):
|
||||||
|
# return self.text.split(" ") # should be customized with the sep
|
||||||
|
|
||||||
def approximate_cut_at(self, cut_time):
|
def approximate_cut_at(self, cut_time):
|
||||||
"""
|
"""
|
||||||
Each word in text is considered to be of duration (end-start)/len(words in text)
|
Each word in text is considered to be of duration (end-start)/len(words in text)
|
||||||
@@ -120,6 +124,19 @@ class Translation(TimedText):
|
|||||||
|
|
||||||
return segment0, segment1
|
return segment0, segment1
|
||||||
|
|
||||||
|
def cut_position(self, position):
|
||||||
|
sep=" "
|
||||||
|
words = self.text.split(sep)
|
||||||
|
num_words = len(words)
|
||||||
|
duration_per_word = self.duration() / num_words
|
||||||
|
cut_time=duration_per_word*position
|
||||||
|
|
||||||
|
text0 = sep.join(words[:position])
|
||||||
|
text1 = sep.join(words[position:])
|
||||||
|
|
||||||
|
segment0 = Translation(start=self.start, end=cut_time, text=text0)
|
||||||
|
segment1 = Translation(start=cut_time, end=self.end, text=text1)
|
||||||
|
return segment0, segment1
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Silence():
|
class Silence():
|
||||||
@@ -143,6 +160,90 @@ class Line(TimedText):
|
|||||||
_dict['detected_language'] = self.detected_language
|
_dict['detected_language'] = self.detected_language
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WordValidation:
|
||||||
|
"""Validation status for word-level data."""
|
||||||
|
text: bool = False
|
||||||
|
speaker: bool = False
|
||||||
|
language: bool = False
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
'text': self.text,
|
||||||
|
'speaker': self.speaker,
|
||||||
|
'language': self.language
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Word:
|
||||||
|
"""Word-level object with timing and validation information."""
|
||||||
|
text: str = ''
|
||||||
|
start: float = 0.0
|
||||||
|
end: float = 0.0
|
||||||
|
validated: WordValidation = field(default_factory=WordValidation)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
'text': self.text,
|
||||||
|
'start': self.start,
|
||||||
|
'end': self.end,
|
||||||
|
'validated': self.validated.to_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SegmentBuffer:
|
||||||
|
"""Per-segment temporary buffers for ephemeral data."""
|
||||||
|
transcription: str = ''
|
||||||
|
diarization: str = ''
|
||||||
|
translation: str = ''
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
'transcription': self.transcription,
|
||||||
|
'diarization': self.diarization,
|
||||||
|
'translation': self.translation
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Segment:
|
||||||
|
"""Represents a segment in the new API structure."""
|
||||||
|
id: int = 0
|
||||||
|
speaker: int = -1
|
||||||
|
text: str = ''
|
||||||
|
start_speaker: float = 0.0
|
||||||
|
start: float = 0.0
|
||||||
|
end: float = 0.0
|
||||||
|
language: Optional[str] = None
|
||||||
|
translation: str = ''
|
||||||
|
words: List[ASRToken] = field(default_factory=list)
|
||||||
|
buffer_tokens: List[ASRToken] = field(default_factory=list)
|
||||||
|
buffer_translation = ''
|
||||||
|
buffer: SegmentBuffer = field(default_factory=SegmentBuffer)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""Convert segment to dictionary for JSON serialization."""
|
||||||
|
return {
|
||||||
|
'id': self.id,
|
||||||
|
'speaker': self.speaker,
|
||||||
|
'text': self.text,
|
||||||
|
'start_speaker': self.start_speaker,
|
||||||
|
'start': self.start,
|
||||||
|
'end': self.end,
|
||||||
|
'language': self.language,
|
||||||
|
'translation': self.translation,
|
||||||
|
'words': [word.to_dict() for word in self.words],
|
||||||
|
'buffer': self.buffer.to_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
def consolidate(self, sep):
|
||||||
|
self.text = sep.join([word.text for word in self.words])
|
||||||
|
if self.words:
|
||||||
|
self.start = self.words[0].start
|
||||||
|
self.end = self.words[-1].end
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FrontData():
|
class FrontData():
|
||||||
@@ -175,7 +276,9 @@ class ChangeSpeaker:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class State():
|
class State():
|
||||||
tokens: list = field(default_factory=list)
|
tokens: list = field(default_factory=list)
|
||||||
|
segments: list = field(default_factory=list)
|
||||||
last_validated_token: int = 0
|
last_validated_token: int = 0
|
||||||
|
last_validated_segment: int = 0 # validated means tokens speaker and transcription are validated and terminated
|
||||||
translation_validated_segments: list = field(default_factory=list)
|
translation_validated_segments: list = field(default_factory=list)
|
||||||
translation_buffer: list = field(default_factory=list)
|
translation_buffer: list = field(default_factory=list)
|
||||||
buffer_transcription: str = field(default_factory=Transcript)
|
buffer_transcription: str = field(default_factory=Transcript)
|
||||||
@@ -183,4 +286,4 @@ class State():
|
|||||||
end_attributed_speaker: float = 0.0
|
end_attributed_speaker: float = 0.0
|
||||||
remaining_time_transcription: float = 0.0
|
remaining_time_transcription: float = 0.0
|
||||||
remaining_time_diarization: float = 0.0
|
remaining_time_diarization: float = 0.0
|
||||||
beg_loop: Optional[int] = None
|
beg_loop: Optional[int] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user