diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py
index 391b91b..8335f55 100644
--- a/whisperlivekit/parse_args.py
+++ b/whisperlivekit/parse_args.py
@@ -112,10 +112,10 @@ def parse_args():
         help="Load only this backend for Whisper processing.",
     )
     parser.add_argument(
-        "--vac",
-        # action="store_true",
-        default=True,
-        help="Use VAC = voice activity controller. Recommended. Requires torch.",
+        "--no-vac",
+        action="store_true",
+        default=False,
+        help="Disable VAC = voice activity controller.",
     )
     parser.add_argument(
         "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index 3573aca..d0eefa6 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -59,9 +59,10 @@ class SimulStreamingOnlineProcessor:
 
         if silence_duration < 5:
             gap_silence = torch.zeros(int(16000*min(silence_duration, 1.0)))
             self.model.insert_audio(gap_silence)
+            self.global_time_offset = silence_duration - 1.0
         else:
             self.model.refresh_segment(complete=True)
-            self.global_time_offset += silence_duration
+            self.global_time_offset += silence_duration
 
@@ -100,29 +101,32 @@ class SimulStreamingOnlineProcessor:
         split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
         progress = generation["progress"]
         frames = [p["most_attended_frames"][0] for p in progress]
+        absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
         tokens_queue = tokens.copy()
         timestamped_words = []
 
         for word, word_tokens in zip(split_words, split_tokens):
-            start_frame = None
-            end_frame = None
+            # start_frame = None
+            # end_frame = None
             for expected_token in word_tokens:
                 if not tokens_queue or not frames:
                     raise ValueError(f"Insufficient tokens or frames for word '{word}'")
                 actual_token = tokens_queue.pop(0)
                 current_frame = frames.pop(0)
+                current_timestamp = absolute_timestamps.pop(0)
                 if actual_token != expected_token:
                     raise ValueError(
                         f"Token mismatch: expected '{expected_token}', "
                         f"got '{actual_token}' at frame {current_frame}"
                     )
-                if start_frame is None:
-                    start_frame = current_frame
-                end_frame = current_frame
-            start_time = start_frame * FRAME_DURATION
-            end_time = end_frame * FRAME_DURATION
-
+                # if start_frame is None:
+                #     start_frame = current_frame
+                # end_frame = current_frame
+            # start_time = start_frame * FRAME_DURATION
+            # end_time = end_frame * FRAME_DURATION
+            start_time = current_timestamp
+            end_time = current_timestamp + 0.1
             timestamp_entry = (start_time, end_time, word)
             timestamped_words.append(timestamp_entry)
             logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
 
@@ -134,7 +138,7 @@ class SimulStreamingOnlineProcessor:
 
         Returns a tuple: (list of committed ASRToken objects,
         float representing the audio processed up to time).
         """
-        try:
+        try:
             tokens, generation_progress = self.model.infer(is_last=self.is_last)
             ts_words = self.timestamped_text(tokens, generation_progress)
@@ -212,7 +216,7 @@ class SimulStreamingASR():
         self.model_path = kwargs.get('model_path', './large-v3.pt')
         self.frame_threshold = kwargs.get('frame_threshold', 25)
-        self.audio_max_len = kwargs.get('audio_max_len', 30.0)
+        self.audio_max_len = kwargs.get('audio_max_len', 20.0)
         self.audio_min_len = kwargs.get('audio_min_len', 0.0)
         self.segment_length = kwargs.get('segment_length', 0.5)
         self.beams = kwargs.get('beams', 1)
diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py
index aa3d794..cad2143 100644
--- a/whisperlivekit/simul_whisper/simul_whisper.py
+++ b/whisperlivekit/simul_whisper/simul_whisper.py
@@ -125,6 +125,7 @@ class PaddedAlignAttWhisper:
         self.init_tokens()
 
         self.last_attend_frame = -self.cfg.rewind_threshold
+        self.cumulative_time_offset = 0.0
 
         if self.cfg.max_context_tokens is None:
             self.max_context_tokens = self.max_text_len
@@ -220,6 +221,7 @@ class PaddedAlignAttWhisper:
         self.init_tokens()
         self.last_attend_frame = -self.cfg.rewind_threshold
         self.detected_language = None
+        self.cumulative_time_offset = 0.0
         self.init_context()
         logger.debug(f"Context: {self.context}")
         if not complete and len(self.segments) > 2:
@@ -287,8 +289,9 @@ class PaddedAlignAttWhisper:
             removed_len = self.segments[0].shape[0] / 16000
             segments_len -= removed_len
             self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
+            self.cumulative_time_offset += removed_len  # Track cumulative time removed
             self.segments = self.segments[1:]
-            logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
+            logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
         if len(self.tokens) > 1:
             self.context.append_token_ids(self.tokens[1][0,:])
             self.tokens = [self.initial_tokens] + self.tokens[2:]
@@ -504,7 +507,13 @@ class PaddedAlignAttWhisper:
             # for each beam, the most attended frame is:
             most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
             generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
+
+            # Calculate absolute timestamps accounting for cumulative offset
+            absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
+            generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
+
             logger.debug(str(most_attended_frames.tolist()) + " most att frames")
+            logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
 
             most_attended_frame = most_attended_frames[0].item()
@@ -609,4 +618,4 @@ class PaddedAlignAttWhisper:
 
         self._clean_cache()
 
-        return new_hypothesis, generation
\ No newline at end of file
+        return new_hypothesis, generation