diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 2ffbc0f..5d77064 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -4,7 +4,7 @@ from time import time, sleep import math import logging import traceback -from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State +from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.results_formater import format_output @@ -58,7 +58,7 @@ class AudioProcessor: self.silence_duration = 0.0 self.tokens = [] self.translated_segments = [] - self.buffer_transcription = "" + self.buffer_transcription = Transcript() self.buffer_diarization = "" self.end_buffer = 0 self.end_attributed_speaker = 0 @@ -114,20 +114,6 @@ class AudioProcessor: """Convert PCM buffer in s16le format to normalized NumPy array.""" return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 - async def update_transcription(self, new_tokens, buffer, end_buffer): - """Thread-safe update of transcription with new data.""" - async with self.lock: - self.tokens.extend(new_tokens) - self.buffer_transcription = buffer - self.end_buffer = end_buffer - - async def update_diarization(self, end_attributed_speaker, buffer_diarization=""): - """Thread-safe update of diarization with new data.""" - async with self.lock: - self.end_attributed_speaker = end_attributed_speaker - if buffer_diarization: - self.buffer_diarization = buffer_diarization - async def add_dummy_token(self): """Placeholder token when no transcription is available.""" async with self.lock: @@ -168,7 +154,7 @@ class AudioProcessor: async with self.lock: self.tokens = [] self.translated_segments = [] - self.buffer_transcription = 
self.buffer_diarization = "" + self.buffer_transcription, self.buffer_diarization = Transcript(), Transcript() self.end_buffer = self.end_attributed_speaker = 0 self.beg_loop = time() @@ -264,30 +250,28 @@ class AudioProcessor: self.online.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm) new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.online.process_iter) - # Get buffer information - _buffer_transcript_obj = self.online.get_buffer() - buffer_text = _buffer_transcript_obj.text + _buffer_transcript = self.online.get_buffer() + buffer_text = _buffer_transcript.text if new_tokens: validated_text = self.sep.join([t.text for t in new_tokens]) if buffer_text.startswith(validated_text): - buffer_text = buffer_text[len(validated_text):].lstrip() + _buffer_transcript.text = buffer_text[len(validated_text):].lstrip() candidate_end_times = [self.end_buffer] if new_tokens: candidate_end_times.append(new_tokens[-1].end) - if _buffer_transcript_obj.end is not None: - candidate_end_times.append(_buffer_transcript_obj.end) + if _buffer_transcript.end is not None: + candidate_end_times.append(_buffer_transcript.end) candidate_end_times.append(current_audio_processed_upto) - new_end_buffer = max(candidate_end_times) - - await self.update_transcription( - new_tokens, buffer_text, new_end_buffer - ) + async with self.lock: + self.tokens.extend(new_tokens) + self.buffer_transcription = _buffer_transcript + self.end_buffer = max(candidate_end_times) if self.translation_queue: for token in new_tokens: @@ -438,8 +422,8 @@ class AudioProcessor: sep=self.sep ) if end_w_silence: - buffer_transcription = '' - buffer_diarization = '' + buffer_transcription = Transcript() + buffer_diarization = Transcript() else: buffer_transcription = state.buffer_transcription buffer_diarization = state.buffer_diarization @@ -449,8 +433,13 @@ class AudioProcessor: combined = self.sep.join(undiarized_text) if buffer_transcription: combined += self.sep - await 
self.update_diarization(state.end_attributed_speaker, combined) - buffer_diarization = combined + + buffer_diarization.text = combined + async with self.lock: + self.end_attributed_speaker = state.end_attributed_speaker + if buffer_diarization: + self.buffer_diarization = buffer_diarization + response_status = "active_transcription" if not state.tokens and not buffer_transcription and not buffer_diarization: response_status = "no_audio_detected" lines = [] elif response_status == "active_transcription" and not lines: lines = [Line( speaker=1, start=format_time(state.end_buffer), end=format_time(state.end_buffer), )] response = FrontData( status=response_status, lines=lines, - buffer_transcription=buffer_transcription, - buffer_diarization=buffer_diarization, + buffer_transcription=buffer_transcription.text, + buffer_diarization=buffer_diarization.text, remaining_time_transcription=state.remaining_time_transcription, remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 ) diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py index 8f30a9a..226e646 100644 --- a/whisperlivekit/results_formater.py +++ b/whisperlivekit/results_formater.py @@ -150,4 +150,8 @@ def format_output(state, silence, current_time, args, debug, sep): else: remaining_segments.append(ts) unassigned_translated_segments = remaining_segments #maybe do smth in the future about that + + if state.buffer_transcription and lines: + lines[-1].end = max(state.buffer_transcription.end, lines[-1].end) + return lines, undiarized_text, end_w_silence diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index ab648db..a4f3aa2 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -52,12 +52,7 @@ class SimulStreamingOnlineProcessor: self.asr = asr self.logfile = logfile self.end = 0.0 - self.buffer = Transcript( - start=None, - end=None, - text='', - probability=None - ) + self.buffer = [] self.committed: List[ASRToken] = [] self.last_result_tokens: List[ASRToken] = [] self.load_new_backend() @@ -103,8 +98,9 @@ class 
SimulStreamingOnlineProcessor: self.model.refresh_segment(complete=True) def get_buffer(self): - return self.buffer - + concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='') + return concat_buffer + def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]: """ Process accumulated audio chunks using SimulStreaming. @@ -112,9 +108,10 @@ class SimulStreamingOnlineProcessor: Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time). """ try: - new_tokens = self.model.infer(is_last=is_last) - self.committed.extend(new_tokens) - return new_tokens, self.end + timestamped_words, timestamped_buffer_language = self.model.infer(is_last=is_last) + self.buffer = timestamped_buffer_language + self.committed.extend(timestamped_words) + return timestamped_words, self.end except Exception as e: diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index b23a57c..0d40bf7 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -74,6 +74,7 @@ class PaddedAlignAttWhisper: ) self.tokenizer_is_multilingual = not model_name.endswith(".en") self.create_tokenizer(cfg.language if cfg.language != "auto" else None) + # self.create_tokenizer('en') self.detected_language = cfg.language if cfg.language != "auto" else None self.global_time_offset = 0.0 self.reset_tokenizer_to_auto_next_call = False @@ -433,21 +434,18 @@ class PaddedAlignAttWhisper: end_encode = time() # print('Encoder duration:', end_encode-beg_encode) - # if self.cfg.language == "auto" and self.detected_language is None: - # seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time - # if seconds_since_start >= 3.0: - # language_tokens, language_probs = self.lang_id(encoder_feature) - # logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}") - # top_lan, p = 
max(language_probs[0].items(), key=lambda x: x[1]) - # logger.info(f"Detected language: {top_lan} with p={p:.4f}") - # #self.tokenizer.language = top_lan - # #self.tokenizer.__post_init__() - # self.create_tokenizer(top_lan) - # self.detected_language = top_lan - # self.init_tokens() - # logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") - # else: - # logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s") + if self.cfg.language == "auto" and self.detected_language is None: + seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time + if seconds_since_start >= 3.0: + language_tokens, language_probs = self.lang_id(encoder_feature) + top_lan, p = max(language_probs[0].items(), key=lambda x: x[1]) + logger.info(f"Detected language: {top_lan} with p={p:.4f}") + self.create_tokenizer(top_lan) + self.refresh_segment(complete=True) + self.detected_language = top_lan + logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}") + else: + logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s") self.trim_context() current_tokens = self._current_tokens() @@ -495,19 +493,6 @@ logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ") self.debug_print_tokens(current_tokens) - # # Early stop on sentence-ending punctuation when language is auto - # if not completed and self.cfg.language == "auto": - # last_token_id = current_tokens[0, -1].item() - # last_token_text = self.tokenizer.decode([last_token_id]).strip() - # if last_token_text in PUNCTUATION_MARKS: - # logger.debug(f"Punctuation boundary '{last_token_text}' hit; stopping early to allow language re-check.") - # punctuation_stop = True - # # Ensure next call starts with auto language (re-detect for new sentence) - # self.reset_tokenizer_to_auto_next_call 
= True - # self.detected_language = None - # self.sentence_start_time = self.cumulative_time_offset + self.segments_len() - # break - attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)] for i, attn_mat in enumerate(self.dec_attns): layer_rank = int(i % len(self.model.decoder.blocks)) @@ -617,10 +602,17 @@ timestamp_entry = ASRToken( start=current_timestamp, end=current_timestamp + 0.1, - text=word, - probability=0.95 + text=word, + probability=0.95, + language=self.detected_language ).with_offset( self.global_time_offset ) timestamped_words.append(timestamp_entry) - return timestamped_words \ No newline at end of file + + if self.detected_language is None and self.cfg.language == "auto": + timestamped_buffer_language, timestamped_words = timestamped_words, [] + else: + timestamped_buffer_language = [] + + return timestamped_words, timestamped_buffer_language diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 722673f..110e98d 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Any +from typing import Optional, Any, List from datetime import timedelta PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'} @@ -17,6 +17,7 @@ class TimedText: speaker: Optional[int] = -1 probability: Optional[float] = None is_dummy: Optional[bool] = False + language: Optional[str] = None def is_punctuation(self): return self.text.strip() in PUNCTUATION_MARKS @@ -35,6 +36,10 @@ class TimedText: def contains_timespan(self, other: 'TimedText') -> bool: return self.start <= other.start and self.end >= other.end + + def __bool__(self): + return bool(self.text) + @dataclass class ASRToken(TimedText): @@ -48,7 +53,28 @@ class Sentence(TimedText): @dataclass class Transcript(TimedText): - pass + """ + represents a concatenation of several ASRToken + """ + + @classmethod + def from_tokens( + cls, + tokens: 
List[ASRToken], + sep: Optional[str] = None, + offset: float = 0 + ) -> "Transcript": + sep = sep if sep is not None else ' ' + text = sep.join(token.text for token in tokens) + probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None + if tokens: + start = offset + tokens[0].start + end = offset + tokens[-1].end + else: + start = None + end = None + return cls(start, end, text, probability=probability) + @dataclass class SpeakerSegment(TimedText):