mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
correct timing (lag) calculations in SimulStreamingASR and SimulStreamingOnlineProcessor
This commit is contained in:
@@ -425,8 +425,10 @@ class SimulStreamingASR(ASRBase):
|
||||
# We dont have word-level timestamps here. 1rst approach, should be improved later.
|
||||
words = text.strip().split()
|
||||
if not words:
|
||||
return tokens
|
||||
duration_per_word = 0.5 # rough estimate of 0.5 seconds per word... Not so great
|
||||
return tokens
|
||||
|
||||
duration_per_word = 0.1 # this will be modified based on actual audio duration
|
||||
#with the SimulStreamingOnlineProcessor
|
||||
|
||||
for i, word in enumerate(words):
|
||||
start_time = i * duration_per_word
|
||||
|
||||
@@ -544,10 +544,12 @@ class SimulStreamingOnlineProcessor:
|
||||
self.beg = self.offset
|
||||
self.end = self.offset
|
||||
self.cumulative_audio_duration = 0.0
|
||||
self.last_audio_stream_end_time = self.offset
|
||||
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.buffer_content = ""
|
||||
self.processed_audio_duration = 0.0
|
||||
|
||||
def get_audio_buffer_end_time(self) -> float:
|
||||
"""Returns the absolute end time of the current audio buffer."""
|
||||
@@ -565,7 +567,12 @@ class SimulStreamingOnlineProcessor:
|
||||
# Update timing
|
||||
chunk_duration = len(audio) / self.SAMPLING_RATE
|
||||
self.cumulative_audio_duration += chunk_duration
|
||||
self.end = self.offset + self.cumulative_audio_duration
|
||||
|
||||
if audio_stream_end_time is not None:
|
||||
self.last_audio_stream_end_time = audio_stream_end_time
|
||||
self.end = audio_stream_end_time
|
||||
else:
|
||||
self.end = self.offset + self.cumulative_audio_duration
|
||||
|
||||
def prompt(self) -> Tuple[str, str]:
|
||||
"""
|
||||
@@ -578,9 +585,10 @@ class SimulStreamingOnlineProcessor:
|
||||
"""
|
||||
Get the unvalidated buffer content.
|
||||
"""
|
||||
buffer_end = self.end if hasattr(self, 'end') else None
|
||||
return Transcript(
|
||||
start=None,
|
||||
end=None,
|
||||
end=buffer_end,
|
||||
text=self.buffer_content,
|
||||
probability=None
|
||||
)
|
||||
@@ -601,12 +609,13 @@ class SimulStreamingOnlineProcessor:
|
||||
else:
|
||||
audio = torch.cat(self.audio_chunks, dim=0)
|
||||
|
||||
if audio.shape[0] > 0:
|
||||
self.end = self.offset + (audio.shape[0] / self.SAMPLING_RATE)
|
||||
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
|
||||
self.processed_audio_duration += audio_duration
|
||||
|
||||
self.audio_chunks = []
|
||||
|
||||
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}")
|
||||
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
|
||||
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
|
||||
|
||||
result = self.asr.model.infer(audio, is_last=self.is_last)
|
||||
|
||||
@@ -624,27 +633,33 @@ class SimulStreamingOnlineProcessor:
|
||||
words = decoded_text.strip().split()
|
||||
new_tokens = []
|
||||
|
||||
current_time = self.beg
|
||||
word_duration = 0.3 # Not great should be improved.
|
||||
|
||||
for word in words:
|
||||
token_start = current_time
|
||||
token_end = current_time + word_duration
|
||||
token = ASRToken(
|
||||
start=token_start,
|
||||
end=token_end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
)
|
||||
new_tokens.append(token)
|
||||
current_time = token_end
|
||||
num_words = len(words)
|
||||
if num_words > 0:
|
||||
# distribute words evenly across the processed audio duration
|
||||
# we NEED that for when we use diarization. Even if that s not perfect
|
||||
start_time = self.end - audio_duration
|
||||
time_per_word = audio_duration / num_words if num_words > 1 else audio_duration
|
||||
|
||||
for i, word in enumerate(words):
|
||||
token_start = start_time + (i * time_per_word)
|
||||
token_end = start_time + ((i + 1) * time_per_word)
|
||||
|
||||
token_end = min(token_end, self.end)
|
||||
|
||||
token = ASRToken(
|
||||
start=token_start,
|
||||
end=token_end,
|
||||
text=word,
|
||||
probability=0.95 # fake prob. Maybe we can extract it from the model?
|
||||
)
|
||||
new_tokens.append(token)
|
||||
|
||||
self.beg = self.end
|
||||
|
||||
self.committed.extend(new_tokens)
|
||||
self.last_result_tokens = new_tokens
|
||||
|
||||
logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens")
|
||||
logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s")
|
||||
return new_tokens, self.end
|
||||
|
||||
return [], self.end
|
||||
|
||||
Reference in New Issue
Block a user