O(n) to O(1) for simulstreaming timestamp determination

This commit is contained in:
Quentin Fuxa
2025-09-21 11:04:00 +02:00
parent e61afdefa3
commit a5503308c5
5 changed files with 98 additions and 71 deletions

View File

@@ -429,7 +429,7 @@ class AudioProcessor:
state = await self.get_current_state()
# Format output
lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
lines, undiarized_text, end_w_silence = format_output(
state,
self.silence,
current_time = time() - self.beg_loop if self.beg_loop else None,
@@ -437,6 +437,13 @@ class AudioProcessor:
debug = self.debug,
sep=self.sep
)
if end_w_silence:
buffer_transcription = ''
buffer_diarization = ''
else:
buffer_transcription = state.buffer_transcription
buffer_diarization = state.buffer_diarization
# Handle undiarized text
if undiarized_text:
combined = self.sep.join(undiarized_text)

View File

@@ -77,15 +77,17 @@ def no_token_to_silence(tokens):
new_tokens.append(token)
return new_tokens
def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
def ends_with_silence(tokens, current_time, vac_detected_silence):
end_w_silence = False
if not tokens:
return [], buffer_transcription, buffer_diarization
return [], end_w_silence
last_token = tokens[-1]
if tokens and current_time and (
current_time - last_token.end >= END_SILENCE_DURATION
or
or
(current_time - last_token.end >= 3 and vac_detected_silence)
):
end_w_silence = True
if last_token.speaker == -2:
last_token.end = current_time
else:
@@ -97,14 +99,12 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_
probability=0.95
)
)
buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
buffer_diarization = ""
return tokens, buffer_transcription, buffer_diarization
return tokens, end_w_silence
def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
def handle_silences(tokens, current_time, vac_detected_silence):
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
return tokens, buffer_transcription, buffer_diarization
tokens, end_w_silence = ends_with_silence(tokens, current_time, vac_detected_silence)
return tokens, end_w_silence

View File

@@ -50,14 +50,12 @@ def format_output(state, silence, current_time, args, debug, sep):
disable_punctuation_split = args.disable_punctuation_split
tokens = state.tokens
translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
buffer_transcription = state.buffer_transcription
buffer_diarization = state.buffer_diarization
end_attributed_speaker = state.end_attributed_speaker
previous_speaker = -1
lines = []
undiarized_text = []
tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
tokens, end_w_silence = handle_silences(tokens, current_time, silence)
last_punctuation = None
for i, token in enumerate(tokens):
speaker = token.speaker
@@ -121,6 +119,7 @@ def format_output(state, silence, current_time, args, debug, sep):
pass
append_token_to_last_line(lines, sep, token, debug_info)
if lines and translated_segments:
unassigned_translated_segments = []
for ts in translated_segments:
@@ -151,4 +150,4 @@ def format_output(state, silence, current_time, args, debug, sep):
else:
remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
return lines, undiarized_text, buffer_transcription, ''
return lines, undiarized_text, end_w_silence

View File

@@ -6,7 +6,6 @@ import logging
import platform
from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment
from whisperlivekit.warmup import load_file
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND
import os
@@ -23,7 +22,11 @@ try:
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
print(f"""
{"="*50}
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
{"="*50}
""")
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
@@ -49,8 +52,12 @@ class SimulStreamingOnlineProcessor:
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.global_time_offset = 0.0
self.buffer = Transcript(
start=None,
end=None,
text='',
probability=None
)
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.load_new_backend()
@@ -79,7 +86,7 @@ class SimulStreamingOnlineProcessor:
else:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
self.model.refresh_segment(complete=True)
self.global_time_offset = silence_duration + offset
self.model.global_time_offset = silence_duration + offset
@@ -96,31 +103,7 @@ class SimulStreamingOnlineProcessor:
self.model.refresh_segment(complete=True)
def get_buffer(self):
return Transcript(
start=None,
end=None,
text='',
probability=None
)
def timestamped_text(self, split_words, split_tokens, l_absolute_timestamps):
timestamped_words = []
for word, word_tokens in zip(split_words, split_tokens):
for i in word_tokens:
current_timestamp = l_absolute_timestamps.pop(0)
timestamp_entry = ASRToken(
start=current_timestamp,
end=current_timestamp + 0.1,
text=word,
probability=0.95
).with_offset(
self.global_time_offset
)
timestamped_words.append(timestamp_entry)
return timestamped_words
return self.buffer
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""
@@ -129,9 +112,7 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
split_words, split_tokens, l_absolute_timestamps = self.model.infer(is_last=is_last)
new_tokens = self.timestamped_text(split_words, split_tokens, l_absolute_timestamps)
new_tokens = self.model.infer(is_last=is_last)
self.committed.extend(new_tokens)
return new_tokens, self.end
@@ -163,7 +144,6 @@ class SimulStreamingASR():
sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = lan

View File

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig
from whisperlivekit.timed_objects import ASRToken
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
@@ -18,6 +19,7 @@ from time import time
from .token_buffer import TokenBuffer
import numpy as np
from ..timed_objects import PUNCTUATION_MARKS
from .generation_progress import *
DEC_PAD = 50257
@@ -40,12 +42,6 @@ else:
except ImportError:
HAS_FASTER_WHISPER = False
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
# - translation support
# - beam search
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
def __init__(
self,
@@ -79,6 +75,9 @@ class PaddedAlignAttWhisper:
self.tokenizer_is_multilingual = not model_name.endswith(".en")
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.detected_language = cfg.language if cfg.language != "auto" else None
self.global_time_offset = 0.0
self.reset_tokenizer_to_auto_next_call = False
self.sentence_start_time = 0.0
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
@@ -153,6 +152,7 @@ class PaddedAlignAttWhisper:
self.last_attend_frame = -self.cfg.rewind_threshold
self.cumulative_time_offset = 0.0
self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
@@ -382,11 +382,11 @@ class PaddedAlignAttWhisper:
new_segment = True
if len(self.segments) == 0:
logger.debug("No segments, nothing to do")
return [], [], []
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.segments, dim=0)
return [], [], []
return []
# input_segments is concatenation of audio, it's one array
if len(self.segments) > 1:
@@ -394,6 +394,13 @@ class PaddedAlignAttWhisper:
else:
input_segments = self.segments[0]
# if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
# logger.debug("Resetting tokenizer to auto for new sentence.")
# self.create_tokenizer(None)
# self.detected_language = None
# self.init_tokens()
# self.reset_tokenizer_to_auto_next_call = False
# NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
beg_encode = time()
if self.mlx_encoder:
@@ -426,17 +433,21 @@ class PaddedAlignAttWhisper:
end_encode = time()
# print('Encoder duration:', end_encode-beg_encode)
if self.cfg.language == "auto" and self.detected_language is None:
language_tokens, language_probs = self.lang_id(encoder_feature)
logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
logger.info(f"Detected language: {top_lan} with p={p:.4f}")
#self.tokenizer.language = top_lan
#self.tokenizer.__post_init__()
self.create_tokenizer(top_lan)
self.detected_language = top_lan
self.init_tokens()
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
# if self.cfg.language == "auto" and self.detected_language is None:
# seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time
# if seconds_since_start >= 3.0:
# language_tokens, language_probs = self.lang_id(encoder_feature)
# logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
# top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
# logger.info(f"Detected language: {top_lan} with p={p:.4f}")
# #self.tokenizer.language = top_lan
# #self.tokenizer.__post_init__()
# self.create_tokenizer(top_lan)
# self.detected_language = top_lan
# self.init_tokens()
# logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
# else:
# logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
self.trim_context()
current_tokens = self._current_tokens()
@@ -446,6 +457,7 @@ class PaddedAlignAttWhisper:
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
completed = False
# punctuation_stop = False
attn_of_alignment_heads = None
most_attended_frame = None
@@ -467,9 +479,7 @@ class PaddedAlignAttWhisper:
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# generation["no_speech_prob"] = no_speech_probs[0]
if no_speech_probs[0] > self.cfg.nonspeech_prob:
# generation["no_speech"] = True
logger.info("no speech, stop")
break
@@ -485,6 +495,19 @@ class PaddedAlignAttWhisper:
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens)
# # Early stop on sentence-ending punctuation when language is auto
# if not completed and self.cfg.language == "auto":
# last_token_id = current_tokens[0, -1].item()
# last_token_text = self.tokenizer.decode([last_token_id]).strip()
# if last_token_text in PUNCTUATION_MARKS:
# logger.debug(f"Punctuation boundary '{last_token_text}' hit; stopping early to allow language re-check.")
# punctuation_stop = True
# # Ensure next call starts with auto language (re-detect for new sentence)
# self.reset_tokenizer_to_auto_next_call = True
# self.detected_language = None
# self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
# break
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
for i, attn_mat in enumerate(self.dec_attns):
layer_rank = int(i % len(self.model.decoder.blocks))
@@ -560,7 +583,7 @@ class PaddedAlignAttWhisper:
tokens_to_split = current_tokens[0, token_len_before_decoding:]
if fire_detected or is_last:
if fire_detected or is_last: #or punctuation_stop:
new_hypothesis = tokens_to_split.flatten().tolist()
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
@@ -582,4 +605,22 @@ class PaddedAlignAttWhisper:
self._clean_cache()
return split_words, split_tokens, l_absolute_timestamps
timestamped_words = []
timestamp_idx = 0
for word, word_tokens in zip(split_words, split_tokens):
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except:
pass
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=current_timestamp,
end=current_timestamp + 0.1,
text=word,
probability=0.95
).with_offset(
self.global_time_offset
)
timestamped_words.append(timestamp_entry)
return timestamped_words