Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07 22:33:36 +00:00)
O(n) to O(1) for simulstreaming timestamp determination
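The hot path here is word-timestamp assignment in the simulstreaming backend. The old `timestamped_text` helper popped each timestamp off the front of a Python list, and `list.pop(0)` shifts every remaining element, so assigning timestamps to n tokens cost O(n) per token and O(n^2) per pass. The new code leaves the list untouched and advances an integer index instead. A minimal sketch of the two patterns (names mirror the diff below; this is an illustration, not the project's code):

    # Before: pop(0) shifts the whole tail on every call -> O(n) per token.
    def assign_timestamps_pop(split_words, split_tokens, l_absolute_timestamps):
        out = []
        for word, word_tokens in zip(split_words, split_tokens):
            for _ in word_tokens:
                current_timestamp = l_absolute_timestamps.pop(0)  # lands on the word's last token
            out.append((word, current_timestamp))
        return out

    # After: read by an advancing index -> O(1) per word, list left intact.
    def assign_timestamps_idx(split_words, split_tokens, l_absolute_timestamps):
        out = []
        timestamp_idx = 0
        for word, word_tokens in zip(split_words, split_tokens):
            current_timestamp = l_absolute_timestamps[timestamp_idx]  # the word's first token
            timestamp_idx += len(word_tokens)
            out.append((word, current_timestamp))
        return out

Note that the two variants pick different entries for multi-token words (last vs. first token timestamp); tokens are fractions of a second apart, so the drift is negligible, and the word's first token is arguably the better anchor for its start time.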
@@ -429,7 +429,7 @@ class AudioProcessor:
         state = await self.get_current_state()
 
         # Format output
-        lines, undiarized_text, buffer_transcription, buffer_diarization = format_output(
+        lines, undiarized_text, end_w_silence = format_output(
             state,
             self.silence,
             current_time = time() - self.beg_loop if self.beg_loop else None,
@@ -437,6 +437,13 @@ class AudioProcessor:
             debug = self.debug,
             sep=self.sep
         )
+        if end_w_silence:
+            buffer_transcription = ''
+            buffer_diarization = ''
+        else:
+            buffer_transcription = state.buffer_transcription
+            buffer_diarization = state.buffer_diarization
+
         # Handle undiarized text
         if undiarized_text:
             combined = self.sep.join(undiarized_text)
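The broader refactor in these hunks is ownership of the partial-hypothesis buffers: instead of threading `buffer_transcription` and `buffer_diarization` through `ends_with_silence`, `handle_silences`, and `format_output` and back, the helpers now pass a single `end_w_silence` boolean up the stack, and `AudioProcessor` resolves the buffers in one place, as shown above.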
@@ -77,15 +77,17 @@ def no_token_to_silence(tokens):
         new_tokens.append(token)
     return new_tokens
 
-def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
+def ends_with_silence(tokens, current_time, vac_detected_silence):
+    end_w_silence = False
     if not tokens:
-        return [], buffer_transcription, buffer_diarization
+        return [], end_w_silence
     last_token = tokens[-1]
     if tokens and current_time and (
         current_time - last_token.end >= END_SILENCE_DURATION
         or
         (current_time - last_token.end >= 3 and vac_detected_silence)
     ):
+        end_w_silence = True
        if last_token.speaker == -2:
            last_token.end = current_time
        else:
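The silence test itself is unchanged by the refactor: the stream counts as ended-by-silence when the last token is at least `END_SILENCE_DURATION` seconds old, or at least 3 seconds old with VAC-detected silence. As a standalone predicate (a sketch; `END_SILENCE_DURATION` is defined in this module, and the value below is only a placeholder):

    END_SILENCE_DURATION = 4.0  # placeholder value; the real constant lives in this module

    def is_end_silence(last_token_end: float, current_time: float, vac_detected_silence: bool) -> bool:
        gap = current_time - last_token_end
        return gap >= END_SILENCE_DURATION or (gap >= 3 and vac_detected_silence)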
@@ -97,14 +99,12 @@ def ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
                     probability=0.95
                 )
             )
-        buffer_transcription = "" # for whisperstreaming backend, we should probably validate the buffer has because of the silence
-        buffer_diarization = ""
-    return tokens, buffer_transcription, buffer_diarization
+    return tokens, end_w_silence
 
 
-def handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence):
+def handle_silences(tokens, current_time, vac_detected_silence):
     tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
     tokens = no_token_to_silence(tokens)
-    tokens, buffer_transcription, buffer_diarization = ends_with_silence(tokens, buffer_transcription, buffer_diarization, current_time, vac_detected_silence)
-    return tokens, buffer_transcription, buffer_diarization
+    tokens, end_w_silence = ends_with_silence(tokens, current_time, vac_detected_silence)
+    return tokens, end_w_silence
@@ -50,14 +50,12 @@ def format_output(state, silence, current_time, args, debug, sep):
     disable_punctuation_split = args.disable_punctuation_split
     tokens = state.tokens
     translated_segments = state.translated_segments # Here we will attribute the speakers only based on the timestamps of the segments
-    buffer_transcription = state.buffer_transcription
-    buffer_diarization = state.buffer_diarization
     end_attributed_speaker = state.end_attributed_speaker
 
     previous_speaker = -1
     lines = []
     undiarized_text = []
-    tokens, buffer_transcription, buffer_diarization = handle_silences(tokens, buffer_transcription, buffer_diarization, current_time, silence)
+    tokens, end_w_silence = handle_silences(tokens, current_time, silence)
     last_punctuation = None
     for i, token in enumerate(tokens):
         speaker = token.speaker
@@ -121,6 +119,7 @@ def format_output(state, silence, current_time, args, debug, sep):
                 pass
 
         append_token_to_last_line(lines, sep, token, debug_info)
+
     if lines and translated_segments:
         unassigned_translated_segments = []
         for ts in translated_segments:
@@ -151,4 +150,4 @@ def format_output(state, silence, current_time, args, debug, sep):
         else:
             remaining_segments.append(ts)
         unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
-    return lines, undiarized_text, buffer_transcription, ''
+    return lines, undiarized_text, end_w_silence
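After these hunks, the formatting pipeline's return shapes line up as follows (a sketch; argument lists abbreviated):

    ends_with_silence(tokens, current_time, vac_detected_silence)  ->  (tokens, end_w_silence)
    handle_silences(tokens, current_time, vac_detected_silence)    ->  (tokens, end_w_silence)
    format_output(state, silence, current_time, ...)               ->  (lines, undiarized_text, end_w_silence)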
@@ -6,7 +6,6 @@ import logging
 import platform
 from whisperlivekit.timed_objects import ASRToken, Transcript, SpeakerSegment
 from whisperlivekit.warmup import load_file
 from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
 from .whisper import load_model, tokenizer
-from .whisper.audio import TOKENS_PER_SECOND
 import os
@@ -23,7 +22,11 @@ try:
     HAS_MLX_WHISPER = True
 except ImportError:
     if platform.system() == "Darwin" and platform.machine() == "arm64":
-        print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
+        print(f"""
+{"="*50}
+MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
+{"="*50}
+""")
     HAS_MLX_WHISPER = False
 if HAS_MLX_WHISPER:
     HAS_FASTER_WHISPER = False
@@ -49,8 +52,12 @@ class SimulStreamingOnlineProcessor:
         self.asr = asr
         self.logfile = logfile
         self.end = 0.0
-        self.global_time_offset = 0.0
-
+        self.buffer = Transcript(
+            start=None,
+            end=None,
+            text='',
+            probability=None
+        )
         self.committed: List[ASRToken] = []
         self.last_result_tokens: List[ASRToken] = []
         self.load_new_backend()
@@ -79,7 +86,7 @@ class SimulStreamingOnlineProcessor:
         else:
             self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
             self.model.refresh_segment(complete=True)
-        self.global_time_offset = silence_duration + offset
+        self.model.global_time_offset = silence_duration + offset
 
 
 
@@ -96,31 +103,7 @@ class SimulStreamingOnlineProcessor:
             self.model.refresh_segment(complete=True)
 
     def get_buffer(self):
-        return Transcript(
-            start=None,
-            end=None,
-            text='',
-            probability=None
-        )
-
-    def timestamped_text(self, split_words, split_tokens, l_absolute_timestamps):
-        timestamped_words = []
-
-        for word, word_tokens in zip(split_words, split_tokens):
-
-            for i in word_tokens:
-                current_timestamp = l_absolute_timestamps.pop(0)
-
-            timestamp_entry = ASRToken(
-                start=current_timestamp,
-                end=current_timestamp + 0.1,
-                text=word,
-                probability=0.95
-            ).with_offset(
-                self.global_time_offset
-            )
-            timestamped_words.append(timestamp_entry)
-        return timestamped_words
+        return self.buffer
 
     def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
         """
@@ -129,9 +112,7 @@ class SimulStreamingOnlineProcessor:
         Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
         """
         try:
-            split_words, split_tokens, l_absolute_timestamps = self.model.infer(is_last=is_last)
-            new_tokens = self.timestamped_text(split_words, split_tokens, l_absolute_timestamps)
-
+            new_tokens = self.model.infer(is_last=is_last)
             self.committed.extend(new_tokens)
             return new_tokens, self.end
 
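With timestamp assembly moved into the model, `PaddedAlignAttWhisper.infer()` returns finished `ASRToken` objects and `process_iter()` shrinks to extend-and-return. A usage sketch of the processor after this change (the driver loop and the audio-feed call are hypothetical, for illustration only; just `process_iter()`'s shape is taken from the diff above):

    # Hypothetical driver around SimulStreamingOnlineProcessor.
    online = SimulStreamingOnlineProcessor(asr)
    for chunk in audio_chunks:                    # PCM float32 chunks, source not shown
        online.insert_audio_chunk(chunk)          # hypothetical feed method
        new_tokens, upto = online.process_iter()  # (List[ASRToken], float)
        for tok in new_tokens:
            print(f"[{tok.start:6.2f}-{tok.end:6.2f}] {tok.text}")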
@@ -163,7 +144,6 @@ class SimulStreamingASR():
     sep = ""
 
     def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
         logger.warning(SIMULSTREAMING_LICENSE)
         self.logfile = logfile
-        self.transcribe_kargs = {}
         self.original_language = lan
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 
 from .whisper import load_model, DecodingOptions, tokenizer
 from .config import AlignAttConfig
+from whisperlivekit.timed_objects import ASRToken
 from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
 from .whisper.timing import median_filter
 from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
@@ -18,6 +19,7 @@ from time import time
 from .token_buffer import TokenBuffer
 
 import numpy as np
+from ..timed_objects import PUNCTUATION_MARKS
 from .generation_progress import *
 
 DEC_PAD = 50257
@@ -40,12 +42,6 @@ else:
 except ImportError:
     HAS_FASTER_WHISPER = False
 
-# New features added to the original version of Simul-Whisper:
-# - large-v3 model support
-# - translation support
-# - beam search
-# - prompt -- static vs. non-static
-# - context
 class PaddedAlignAttWhisper:
     def __init__(
         self,
@@ -79,6 +75,9 @@ class PaddedAlignAttWhisper:
         self.tokenizer_is_multilingual = not model_name.endswith(".en")
         self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
         self.detected_language = cfg.language if cfg.language != "auto" else None
+        self.global_time_offset = 0.0
+        self.reset_tokenizer_to_auto_next_call = False
+        self.sentence_start_time = 0.0
 
         self.max_text_len = self.model.dims.n_text_ctx
         self.num_decoder_layers = len(self.model.decoder.blocks)
@@ -153,6 +152,7 @@ class PaddedAlignAttWhisper:
 
         self.last_attend_frame = -self.cfg.rewind_threshold
         self.cumulative_time_offset = 0.0
+        self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
 
         if self.cfg.max_context_tokens is None:
             self.max_context_tokens = self.max_text_len
@@ -382,11 +382,11 @@ class PaddedAlignAttWhisper:
         new_segment = True
         if len(self.segments) == 0:
             logger.debug("No segments, nothing to do")
-            return [], [], []
+            return []
         if not self._apply_minseglen():
             logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
             input_segments = torch.cat(self.segments, dim=0)
-            return [], [], []
+            return []
 
         # input_segments is concatenation of audio, it's one array
         if len(self.segments) > 1:
@@ -394,6 +394,13 @@ class PaddedAlignAttWhisper:
         else:
             input_segments = self.segments[0]
 
+        # if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
+        #     logger.debug("Resetting tokenizer to auto for new sentence.")
+        #     self.create_tokenizer(None)
+        #     self.detected_language = None
+        #     self.init_tokens()
+        #     self.reset_tokenizer_to_auto_next_call = False
+
         # NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder
         beg_encode = time()
         if self.mlx_encoder:
@@ -426,17 +433,21 @@ class PaddedAlignAttWhisper:
         end_encode = time()
         # print('Encoder duration:', end_encode-beg_encode)
 
-        if self.cfg.language == "auto" and self.detected_language is None:
-            language_tokens, language_probs = self.lang_id(encoder_feature)
-            logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
-            top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
-            logger.info(f"Detected language: {top_lan} with p={p:.4f}")
-            #self.tokenizer.language = top_lan
-            #self.tokenizer.__post_init__()
-            self.create_tokenizer(top_lan)
-            self.detected_language = top_lan
-            self.init_tokens()
-            logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
+        # if self.cfg.language == "auto" and self.detected_language is None:
+        #     seconds_since_start = (self.cumulative_time_offset + self.segments_len()) - self.sentence_start_time
+        #     if seconds_since_start >= 3.0:
+        #         language_tokens, language_probs = self.lang_id(encoder_feature)
+        #         logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
+        #         top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
+        #         logger.info(f"Detected language: {top_lan} with p={p:.4f}")
+        #         #self.tokenizer.language = top_lan
+        #         #self.tokenizer.__post_init__()
+        #         self.create_tokenizer(top_lan)
+        #         self.detected_language = top_lan
+        #         self.init_tokens()
+        #         logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
+        #     else:
+        #         logger.debug(f"Skipping language detection: {seconds_since_start:.2f}s < 3.0s")
 
         self.trim_context()
         current_tokens = self._current_tokens()
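Worth flagging: the active language auto-detection block is commented out here rather than rewritten, so with `language == "auto"` no detection runs inside `infer()` after this commit; the commented variant sketches a future rate-limited re-detection (at most once per 3 s of audio per sentence).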
@@ -446,6 +457,7 @@ class PaddedAlignAttWhisper:
 
         sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
         completed = False
+        # punctuation_stop = False
 
         attn_of_alignment_heads = None
         most_attended_frame = None
@@ -467,9 +479,7 @@ class PaddedAlignAttWhisper:
             if new_segment and self.tokenizer.no_speech is not None:
                 probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
                 no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
-                # generation["no_speech_prob"] = no_speech_probs[0]
                 if no_speech_probs[0] > self.cfg.nonspeech_prob:
-                    # generation["no_speech"] = True
                     logger.info("no speech, stop")
                     break
 
@@ -485,6 +495,19 @@ class PaddedAlignAttWhisper:
             logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
             self.debug_print_tokens(current_tokens)
 
+            # # Early stop on sentence-ending punctuation when language is auto
+            # if not completed and self.cfg.language == "auto":
+            #     last_token_id = current_tokens[0, -1].item()
+            #     last_token_text = self.tokenizer.decode([last_token_id]).strip()
+            #     if last_token_text in PUNCTUATION_MARKS:
+            #         logger.debug(f"Punctuation boundary '{last_token_text}' hit; stopping early to allow language re-check.")
+            #         punctuation_stop = True
+            #         # Ensure next call starts with auto language (re-detect for new sentence)
+            #         self.reset_tokenizer_to_auto_next_call = True
+            #         self.detected_language = None
+            #         self.sentence_start_time = self.cumulative_time_offset + self.segments_len()
+            #         break
+
             attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
             for i, attn_mat in enumerate(self.dec_attns):
                 layer_rank = int(i % len(self.model.decoder.blocks))
@@ -560,7 +583,7 @@ class PaddedAlignAttWhisper:
 
         tokens_to_split = current_tokens[0, token_len_before_decoding:]
 
-        if fire_detected or is_last:
+        if fire_detected or is_last: #or punctuation_stop:
             new_hypothesis = tokens_to_split.flatten().tolist()
             split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
         else:
@@ -582,4 +605,22 @@ class PaddedAlignAttWhisper:
 
         self._clean_cache()
 
-        return split_words, split_tokens, l_absolute_timestamps
+        timestamped_words = []
+        timestamp_idx = 0
+        for word, word_tokens in zip(split_words, split_tokens):
+            try:
+                current_timestamp = l_absolute_timestamps[timestamp_idx]
+            except:
+                pass
+            timestamp_idx += len(word_tokens)
+
+            timestamp_entry = ASRToken(
+                start=current_timestamp,
+                end=current_timestamp + 0.1,
+                text=word,
+                probability=0.95
+            ).with_offset(
+                self.global_time_offset
+            )
+            timestamped_words.append(timestamp_entry)
+        return timestamped_words
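This relocated loop is the indexed variant from the sketch at the top of the page: `l_absolute_timestamps` is never mutated and `timestamp_idx` advances by `len(word_tokens)` per word. One caveat: the bare `try/except: pass` means that if the index ever runs past the list, the previous word's timestamp is silently reused (and the very first word would raise `NameError` if the list came back empty). If consuming the list were still wanted, `collections.deque` gives the same O(1) behavior with explicit pops; an alternative sketch, not what the commit does:

    from collections import deque

    def assign_timestamps_deque(split_words, split_tokens, l_absolute_timestamps):
        timestamps = deque(l_absolute_timestamps)  # O(n) to build once, O(1) per popleft
        out = []
        for word, word_tokens in zip(split_words, split_tokens):
            current_timestamp = timestamps[0] if timestamps else 0.0  # word's first token
            for _ in word_tokens:
                if timestamps:
                    timestamps.popleft()
            out.append((word, current_timestamp))
        return out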