simulstreaming: cumulative_time_offset to keep timestamps correct when audio > 30s

This commit is contained in:
Quentin Fuxa
2025-08-17 09:33:47 +02:00
parent 820f92d8cb
commit d0e9e37ef6
3 changed files with 30 additions and 17 deletions

View File

@@ -112,10 +112,10 @@ def parse_args():
help="Load only this backend for Whisper processing.",
)
parser.add_argument(
"--vac",
# action="store_true",
default=True,
help="Use VAC = voice activity controller. Recommended. Requires torch.",
"--no-vac",
action="store_true",
default=False,
help="Disable VAC = voice activity controller.",
)
parser.add_argument(
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."

View File

@@ -59,9 +59,10 @@ class SimulStreamingOnlineProcessor:
if silence_duration < 5:
gap_silence = torch.zeros(int(16000*min(silence_duration, 1.0)))
self.model.insert_audio(gap_silence)
self.global_time_offset = silence_duration - 1.0
else:
self.model.refresh_segment(complete=True)
self.global_time_offset += silence_duration
self.global_time_offset += silence_duration
@@ -100,29 +101,32 @@ class SimulStreamingOnlineProcessor:
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
progress = generation["progress"]
frames = [p["most_attended_frames"][0] for p in progress]
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
tokens_queue = tokens.copy()
timestamped_words = []
for word, word_tokens in zip(split_words, split_tokens):
start_frame = None
end_frame = None
# start_frame = None
# end_frame = None
for expected_token in word_tokens:
if not tokens_queue or not frames:
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
actual_token = tokens_queue.pop(0)
current_frame = frames.pop(0)
current_timestamp = absolute_timestamps.pop(0)
if actual_token != expected_token:
raise ValueError(
f"Token mismatch: expected '{expected_token}', "
f"got '{actual_token}' at frame {current_frame}"
)
if start_frame is None:
start_frame = current_frame
end_frame = current_frame
start_time = start_frame * FRAME_DURATION
end_time = end_frame * FRAME_DURATION
# if start_frame is None:
# start_frame = current_frame
# end_frame = current_frame
# start_time = start_frame * FRAME_DURATION
# end_time = end_frame * FRAME_DURATION
start_time = current_timestamp
end_time = current_timestamp + 0.1
timestamp_entry = (start_time, end_time, word)
timestamped_words.append(timestamp_entry)
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
@@ -134,7 +138,7 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
try:
tokens, generation_progress = self.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
@@ -212,7 +216,7 @@ class SimulStreamingASR():
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 30.0)
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)

View File

@@ -125,6 +125,7 @@ class PaddedAlignAttWhisper:
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
@@ -220,6 +221,7 @@ class PaddedAlignAttWhisper:
self.init_tokens()
self.last_attend_frame = -self.cfg.rewind_threshold
self.detected_language = None
self.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2:
@@ -287,8 +289,9 @@ class PaddedAlignAttWhisper:
removed_len = self.segments[0].shape[0] / 16000
segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
self.cumulative_time_offset += removed_len # Track cumulative time removed
self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
if len(self.tokens) > 1:
self.context.append_token_ids(self.tokens[1][0,:])
self.tokens = [self.initial_tokens] + self.tokens[2:]
@@ -504,7 +507,13 @@ class PaddedAlignAttWhisper:
# for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
# Calculate absolute timestamps accounting for cumulative offset
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
generation_progress_loop.append(("absolute_timestamps", absolute_timestamps))
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
most_attended_frame = most_attended_frames[0].item()
@@ -609,4 +618,4 @@ class PaddedAlignAttWhisper:
self._clean_cache()
return new_hypothesis, generation
return new_hypothesis, generation