mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
internal rework 3
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment, Silence
|
||||
from whisperlivekit.timed_objects import Line, SilentLine, format_time, SpeakerSegment, Silence
|
||||
from whisperlivekit.timed_objects import PunctuationSegment
|
||||
from time import time
|
||||
|
||||
|
||||
@@ -40,12 +41,10 @@ class TokensAlignment:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
current_line_tokens = []
|
||||
end_silence = token.end if token.has_ended else time() - beg_loop
|
||||
if lines and lines[-1].speaker == -2:
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(Line(
|
||||
speaker = -2,
|
||||
text = '',
|
||||
lines.append(SilentLine(
|
||||
start = token.start,
|
||||
end = end_silence
|
||||
))
|
||||
@@ -55,12 +54,10 @@ class TokensAlignment:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - beg_loop
|
||||
if lines and lines[-1].speaker == -2:
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
else:
|
||||
lines.append(Line(
|
||||
speaker = -2,
|
||||
text = '',
|
||||
lines.append(SilentLine(
|
||||
start = current_silence.start,
|
||||
end = end_silence
|
||||
))
|
||||
@@ -73,28 +70,43 @@ class TokensAlignment:
|
||||
# return self.all_tokens
|
||||
|
||||
def compute_punctuations_segments(self):
|
||||
punctuations_breaks = []
|
||||
new_tokens = self.state.tokens[self.state.last_validated_token:]
|
||||
for i in range(len(new_tokens)):
|
||||
token = new_tokens[i]
|
||||
if token.is_punctuation():
|
||||
punctuations_breaks.append({
|
||||
'token_index': i,
|
||||
'token': token,
|
||||
'start': token.start,
|
||||
'end': token.end,
|
||||
})
|
||||
punctuations_segments = []
|
||||
for i, break_info in enumerate(punctuations_breaks):
|
||||
start = punctuations_breaks[i - 1]['end'] if i > 0 else 0.0
|
||||
end = break_info['end']
|
||||
punctuations_segments.append({
|
||||
'start': start,
|
||||
'end': end,
|
||||
'token_index': break_info['token_index'],
|
||||
'token': break_info['token']
|
||||
})
|
||||
return punctuations_segments
|
||||
"""Compute segments of text between punctuation marks.
|
||||
|
||||
Returns a list of PunctuationSegment objects, each representing
|
||||
the text from the start (or previous punctuation) to the current punctuation mark.
|
||||
"""
|
||||
|
||||
if not self.all_tokens:
|
||||
return []
|
||||
punctuation_indices = [
|
||||
i for i, token in enumerate(self.all_tokens)
|
||||
if token.is_punctuation()
|
||||
]
|
||||
if not punctuation_indices:
|
||||
return []
|
||||
|
||||
segments = []
|
||||
for i, punct_idx in enumerate(punctuation_indices):
|
||||
start_idx = punctuation_indices[i - 1] + 1 if i > 0 else 0
|
||||
end_idx = punct_idx
|
||||
if start_idx <= end_idx:
|
||||
segment = PunctuationSegment.from_token_range(
|
||||
tokens=self.all_tokens,
|
||||
token_index_start=start_idx,
|
||||
token_index_end=end_idx,
|
||||
punctuation_token_index=punct_idx
|
||||
)
|
||||
segments.append(segment)
|
||||
return segments
|
||||
|
||||
|
||||
def concatenate_diar_segments(self):
|
||||
diarization_segments = self.state.diarization_segments
|
||||
if not self.all_diarization_segments:
|
||||
return []
|
||||
merged = [self.all_diarization_segments[0]]
|
||||
for segment in self.all_diarization_segments[1:]:
|
||||
if segment.speaker == merged[-1].speaker:
|
||||
merged[-1].end = segment.end
|
||||
else:
|
||||
merged.append(segment)
|
||||
return merged
|
||||
@@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
MILENCE_DURATION = 3
|
||||
|
||||
def cut_at(cumulative_pcm, cut_sec):
|
||||
cumulative_len = 0
|
||||
@@ -164,7 +165,8 @@ class AudioProcessor:
|
||||
self.current_silence.is_starting=False
|
||||
self.current_silence.has_ended=True
|
||||
self.current_silence.compute_duration()
|
||||
self.state_light.new_tokens.append(self.current_silence)
|
||||
if self.current_silence.duration > MILENCE_DURATION:
|
||||
self.state_light.new_tokens.append(self.current_silence)
|
||||
await self._push_silence_event()
|
||||
self.current_silence = None
|
||||
|
||||
@@ -365,7 +367,6 @@ class AudioProcessor:
|
||||
self.diarization.insert_audio_chunk(item)
|
||||
diarization_segments = await self.diarization.diarize()
|
||||
self.state_light.new_diarization = diarization_segments
|
||||
self.state_light.new_diarization_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
|
||||
@@ -158,7 +158,13 @@ class Line(TimedText):
|
||||
self.speaker = 1
|
||||
return self
|
||||
|
||||
|
||||
def is_silent(self) -> bool:
|
||||
return self.speaker == -2
|
||||
|
||||
class SilentLine(Line):
|
||||
speaker = -2
|
||||
text = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontData():
|
||||
@@ -185,6 +191,45 @@ class FrontData():
|
||||
_dict['error'] = self.error
|
||||
return _dict
|
||||
|
||||
@dataclass
|
||||
class PunctuationSegment(TimedText):
|
||||
"""Represents a segment of text between punctuation marks."""
|
||||
token_index_start: int
|
||||
token_index_end: int
|
||||
punctuation_token_index: int
|
||||
punctuation_token: ASRToken
|
||||
|
||||
@classmethod
|
||||
def from_token_range(
|
||||
cls,
|
||||
tokens: List[ASRToken],
|
||||
token_index_start: int,
|
||||
token_index_end: int,
|
||||
punctuation_token_index: int
|
||||
) -> "PunctuationSegment":
|
||||
"""Create a PunctuationSegment from a range of tokens ending at a punctuation mark."""
|
||||
if not tokens or token_index_start < 0 or token_index_end >= len(tokens):
|
||||
raise ValueError("Invalid token indices")
|
||||
|
||||
start_token = tokens[token_index_start]
|
||||
end_token = tokens[token_index_end]
|
||||
punctuation_token = tokens[punctuation_token_index]
|
||||
|
||||
# Build text from tokens in the segment
|
||||
segment_tokens = tokens[token_index_start:token_index_end + 1]
|
||||
text = ''.join(token.text for token in segment_tokens)
|
||||
|
||||
return cls(
|
||||
start=start_token.start,
|
||||
end=end_token.end,
|
||||
text=text,
|
||||
token_index_start=token_index_start,
|
||||
token_index_end=token_index_end,
|
||||
punctuation_token_index=punctuation_token_index,
|
||||
punctuation_token=punctuation_token
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChangeSpeaker:
|
||||
speaker: int
|
||||
|
||||
Reference in New Issue
Block a user