correct timing (lag) calculations in SimulStreamingASR and SimulStreamingOnlineProcessor

Quentin Fuxa
2025-06-26 00:13:44 +02:00
parent bfec335a5f
commit 8e30e8010a
2 changed files with 39 additions and 22 deletions


@@ -425,8 +425,10 @@ class SimulStreamingASR(ASRBase):
         # We don't have word-level timestamps here. First approach; should be improved later.
         words = text.strip().split()
         if not words:
             return tokens
-        duration_per_word = 0.5  # rough estimate of 0.5 seconds per word... not so great
+        duration_per_word = 0.1  # this will be modified based on actual audio duration
+        # with the SimulStreamingOnlineProcessor
         for i, word in enumerate(words):
             start_time = i * duration_per_word

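For readers skimming the diff: the 0.1 s value is only a placeholder. A minimal sketch of how these provisional timestamps behave before the processor rescales them (ASRToken is reduced to a hypothetical dataclass here; the real class lives elsewhere in the repo):

    from dataclasses import dataclass

    @dataclass
    class ASRToken:          # simplified stand-in for the project's ASRToken
        start: float
        end: float
        text: str

    def provisional_tokens(text, duration_per_word=0.1):
        # Mirrors the hunk above: each word gets a fixed provisional slot.
        tokens = []
        for i, word in enumerate(text.strip().split()):
            start_time = i * duration_per_word
            tokens.append(ASRToken(start_time, start_time + duration_per_word, word))
        return tokens

    # "hello world again" -> starts at 0.0, 0.1, 0.2; SimulStreamingOnlineProcessor
    # later replaces these with timings derived from the real audio duration.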

@@ -544,10 +544,12 @@ class SimulStreamingOnlineProcessor:
         self.beg = self.offset
         self.end = self.offset
         self.cumulative_audio_duration = 0.0
+        self.last_audio_stream_end_time = self.offset
         self.committed: List[ASRToken] = []
         self.last_result_tokens: List[ASRToken] = []
         self.buffer_content = ""
+        self.processed_audio_duration = 0.0

     def get_audio_buffer_end_time(self) -> float:
         """Returns the absolute end time of the current audio buffer."""
@@ -565,7 +567,12 @@ class SimulStreamingOnlineProcessor:
         # Update timing
         chunk_duration = len(audio) / self.SAMPLING_RATE
         self.cumulative_audio_duration += chunk_duration
-        self.end = self.offset + self.cumulative_audio_duration
+        if audio_stream_end_time is not None:
+            self.last_audio_stream_end_time = audio_stream_end_time
+            self.end = audio_stream_end_time
+        else:
+            self.end = self.offset + self.cumulative_audio_duration

     def prompt(self) -> Tuple[str, str]:
         """
@@ -578,9 +585,10 @@ class SimulStreamingOnlineProcessor:
"""
Get the unvalidated buffer content.
"""
buffer_end = self.end if hasattr(self, 'end') else None
return Transcript(
start=None,
end=None,
end=buffer_end,
text=self.buffer_content,
probability=None
)
@@ -601,12 +609,13 @@ class SimulStreamingOnlineProcessor:
         else:
             audio = torch.cat(self.audio_chunks, dim=0)

-        if audio.shape[0] > 0:
-            self.end = self.offset + (audio.shape[0] / self.SAMPLING_RATE)
+        audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
+        self.processed_audio_duration += audio_duration
         self.audio_chunks = []

-        logger.debug(f"SimulStreaming processing audio shape: {audio.shape}")
+        logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
+        logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")

         result = self.asr.model.infer(audio, is_last=self.is_last)
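Note that audio_duration here is the duration of the chunk actually sent to infer(), not the whole stream; processed_audio_duration accumulates it across calls. A toy check under that assumption:

    import torch

    SAMPLING_RATE = 16000
    processed_audio_duration = 0.0
    for n_samples in (8000, 24000):             # 0.5 s, then 1.5 s of audio
        audio = torch.zeros(n_samples)
        audio_duration = audio.shape[0] / SAMPLING_RATE if audio.shape[0] > 0 else 0
        processed_audio_duration += audio_duration
    assert processed_audio_duration == 2.0      # total audio fed to the model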
@@ -624,27 +633,33 @@ class SimulStreamingOnlineProcessor:
             words = decoded_text.strip().split()
             new_tokens = []
-            current_time = self.beg
-            word_duration = 0.3  # not great; should be improved
-            for word in words:
-                token_start = current_time
-                token_end = current_time + word_duration
-                token = ASRToken(
-                    start=token_start,
-                    end=token_end,
-                    text=word,
-                    probability=0.95  # fake prob. Maybe we can extract it from the model?
-                )
-                new_tokens.append(token)
-                current_time = token_end
+            num_words = len(words)
+            if num_words > 0:
+                # distribute words evenly across the processed audio duration;
+                # we NEED that for when we use diarization, even if that's not perfect
+                start_time = self.end - audio_duration
+                time_per_word = audio_duration / num_words if num_words > 1 else audio_duration
+                for i, word in enumerate(words):
+                    token_start = start_time + (i * time_per_word)
+                    token_end = start_time + ((i + 1) * time_per_word)
+                    token_end = min(token_end, self.end)
+                    token = ASRToken(
+                        start=token_start,
+                        end=token_end,
+                        text=word,
+                        probability=0.95  # fake prob. Maybe we can extract it from the model?
+                    )
+                    new_tokens.append(token)

             self.beg = self.end
             self.committed.extend(new_tokens)
             self.last_result_tokens = new_tokens
-            logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens")
+            logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s")

             return new_tokens, self.end
         return [], self.end
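
Worked example of the new even distribution, as a self-contained sketch (the simplified ASRToken and the helper name distribute are illustrative, not part of the codebase):

    from dataclasses import dataclass

    @dataclass
    class ASRToken:
        start: float
        end: float
        text: str
        probability: float

    def distribute(words, end, audio_duration):
        # Same arithmetic as the hunk above: anchor at (end - audio_duration)
        # and give every word an equal share of the processed audio.
        start_time = end - audio_duration
        time_per_word = audio_duration / len(words)
        return [
            ASRToken(
                start=start_time + i * time_per_word,
                end=min(start_time + (i + 1) * time_per_word, end),
                text=w,
                probability=0.95,  # placeholder probability, as in the diff
            )
            for i, w in enumerate(words)
        ]

    # Four words over 2.0 s of audio ending at t = 12.0 s
    # -> boundaries at 10.0, 10.5, 11.0, 11.5, 12.0:
    tokens = distribute("so far so good".split(), end=12.0, audio_duration=2.0)
    assert tokens[0].start == 10.0 and tokens[-1].end == 12.0

Unlike the old fixed 0.3 s per word, the last token can never run past self.end (the min clamp), so downstream consumers such as diarization see timestamps bounded by the audio actually processed.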