mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Language detection after a few seconds now works
This commit is contained in:
@@ -59,7 +59,6 @@ class AudioProcessor:
|
||||
self.tokens = []
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = Transcript()
|
||||
self.buffer_diarization = ""
|
||||
self.end_buffer = 0
|
||||
self.end_attributed_speaker = 0
|
||||
self.lock = asyncio.Lock()
|
||||
@@ -142,7 +141,6 @@ class AudioProcessor:
|
||||
tokens=self.tokens.copy(),
|
||||
translated_segments=self.translated_segments.copy(),
|
||||
buffer_transcription=self.buffer_transcription,
|
||||
buffer_diarization=self.buffer_diarization,
|
||||
end_buffer=self.end_buffer,
|
||||
end_attributed_speaker=self.end_attributed_speaker,
|
||||
remaining_time_transcription=remaining_transcription,
|
||||
@@ -154,7 +152,7 @@ class AudioProcessor:
|
||||
async with self.lock:
|
||||
self.tokens = []
|
||||
self.translated_segments = []
|
||||
self.buffer_transcription = self.buffer_diarization = Transcript()
|
||||
self.buffer_transcription = Transcript()
|
||||
self.end_buffer = self.end_attributed_speaker = 0
|
||||
self.beg_loop = time()
|
||||
|
||||
@@ -297,7 +295,7 @@ class AudioProcessor:
|
||||
|
||||
async def diarization_processor(self, diarization_obj):
|
||||
"""Process audio chunks for speaker diarization."""
|
||||
buffer_diarization = ""
|
||||
buffer_diarization = Transcript()
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
while True:
|
||||
try:
|
||||
@@ -318,15 +316,15 @@ class AudioProcessor:
|
||||
# Process diarization
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
|
||||
segments = diarization_obj.get_segments()
|
||||
|
||||
async with self.lock:
|
||||
self.tokens, last_segment = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
# if last_segment is not None and last_segment.speaker != self.last_detected_speaker:
|
||||
# if not self.speaker_languages.get(last_segment.speaker, None):
|
||||
@@ -423,23 +421,15 @@ class AudioProcessor:
|
||||
)
|
||||
if end_w_silence:
|
||||
buffer_transcription = Transcript()
|
||||
buffer_diarization = Transcript()
|
||||
else:
|
||||
buffer_transcription = state.buffer_transcription
|
||||
buffer_diarization = state.buffer_diarization
|
||||
|
||||
# Handle undiarized text
|
||||
buffer_diarization = ''
|
||||
if undiarized_text:
|
||||
combined = self.sep.join(undiarized_text)
|
||||
if buffer_transcription:
|
||||
combined += self.sep
|
||||
buffer_diarization = self.sep.join(undiarized_text)
|
||||
|
||||
async with self.lock:
|
||||
self.end_attributed_speaker = state.end_attributed_speaker
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
buffer_diarization.text = combined
|
||||
|
||||
response_status = "active_transcription"
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
@@ -456,7 +446,7 @@ class AudioProcessor:
|
||||
status=response_status,
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription.text,
|
||||
buffer_diarization=buffer_transcription.text,
|
||||
buffer_diarization=buffer_diarization,
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
|
||||
@@ -242,7 +242,7 @@ class DiartDiarization:
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
else:
|
||||
tokens = add_speaker_to_tokens(segments, tokens)
|
||||
return tokens, segments[-1]
|
||||
return tokens
|
||||
|
||||
def concatenate_speakers(segments):
|
||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||
|
||||
@@ -296,7 +296,7 @@ class SortformerDiarizationOnline:
|
||||
|
||||
if not segments or not tokens:
|
||||
logger.debug("No segments or tokens available for speaker assignment")
|
||||
return tokens, None
|
||||
return tokens
|
||||
|
||||
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||
use_punctuation_split = False
|
||||
@@ -313,7 +313,7 @@ class SortformerDiarizationOnline:
|
||||
# Use punctuation-aware assignment (similar to diart_backend)
|
||||
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
|
||||
|
||||
return tokens, segments[-1]
|
||||
return tokens
|
||||
|
||||
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
|
||||
"""
|
||||
|
||||
@@ -38,12 +38,16 @@ def new_line(
|
||||
text = token.text + debug_info,
|
||||
start = token.start,
|
||||
end = token.end,
|
||||
detected_language=token.detected_language
|
||||
)
|
||||
|
||||
def append_token_to_last_line(lines, sep, token, debug_info):
|
||||
if token.text:
|
||||
lines[-1].text += sep + token.text + debug_info
|
||||
lines[-1].end = token.end
|
||||
if not lines[-1].detected_language and token.detected_language:
|
||||
lines[-1].detected_language = token.detected_language
|
||||
|
||||
|
||||
def format_output(state, silence, current_time, args, debug, sep):
|
||||
diarization = args.diarization
|
||||
|
||||
@@ -108,9 +108,13 @@ class SimulStreamingOnlineProcessor:
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
try:
|
||||
timestamped_words, timestamped_buffer_language = self.model.infer(is_last=is_last)
|
||||
self.buffer = timestamped_buffer_language
|
||||
timestamped_words = self.model.infer(is_last=is_last)
|
||||
if timestamped_words and timestamped_words[0].detected_language == None:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
self.committed.extend(timestamped_words)
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
|
||||
|
||||
|
||||
@@ -78,7 +78,6 @@ class PaddedAlignAttWhisper:
|
||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
self.global_time_offset = 0.0
|
||||
self.reset_tokenizer_to_auto_next_call = False
|
||||
self.sentence_start_time = 0.0
|
||||
|
||||
self.max_text_len = self.model.dims.n_text_ctx
|
||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||
@@ -153,7 +152,7 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
|
||||
self.second_word_timestamp = None
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
@@ -261,7 +260,6 @@ class PaddedAlignAttWhisper:
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.context}")
|
||||
if not complete and len(self.segments) > 2:
|
||||
logger.debug("keeping last two segments because they are and it is not complete.")
|
||||
self.segments = self.segments[-2:]
|
||||
else:
|
||||
logger.debug("removing all segments.")
|
||||
@@ -434,14 +432,17 @@ class PaddedAlignAttWhisper:
|
||||
end_encode = time()
|
||||
# print('Encoder duration:', end_encode-beg_encode)
|
||||
|
||||
if self.cfg.language == "auto" and self.detected_language is None:
|
||||
seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time
|
||||
if seconds_since_start >= 3.0:
|
||||
if self.cfg.language == "auto" and self.detected_language is None and self.second_word_timestamp:
|
||||
seconds_since_start = self.segments_len() - self.second_word_timestamp
|
||||
if seconds_since_start >= 5.0:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
self.create_tokenizer(top_lan)
|
||||
self.refresh_segment(complete=True)
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
self.detected_language = top_lan
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||
else:
|
||||
@@ -590,6 +591,10 @@ class PaddedAlignAttWhisper:
|
||||
|
||||
self._clean_cache()
|
||||
|
||||
if len(l_absolute_timestamps) >=2 and self.second_word_timestamp is None:
|
||||
self.second_word_timestamp = l_absolute_timestamps[1]
|
||||
|
||||
|
||||
timestamped_words = []
|
||||
timestamp_idx = 0
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
@@ -604,15 +609,10 @@ class PaddedAlignAttWhisper:
|
||||
end=current_timestamp + 0.1,
|
||||
text= word,
|
||||
probability=0.95,
|
||||
language=self.detected_language
|
||||
detected_language=self.detected_language
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
|
||||
if self.detected_language is None and self.cfg.language == "auto":
|
||||
timestamped_buffer_language, timestamped_words = timestamped_words, []
|
||||
else:
|
||||
timestamped_buffer_language = []
|
||||
|
||||
return timestamped_words, timestamped_buffer_language
|
||||
return timestamped_words
|
||||
|
||||
@@ -17,7 +17,7 @@ class TimedText:
|
||||
speaker: Optional[int] = -1
|
||||
probability: Optional[float] = None
|
||||
is_dummy: Optional[bool] = False
|
||||
language: str = None
|
||||
detected_language: Optional[str] = None
|
||||
|
||||
def is_punctuation(self):
|
||||
return self.text.strip() in PUNCTUATION_MARKS
|
||||
@@ -41,11 +41,11 @@ class TimedText:
|
||||
return bool(self.text)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language)
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
@@ -123,7 +123,6 @@ class Silence():
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
detected_language: str = None
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
|
||||
Reference in New Issue
Block a user