modify import paths in simul whisper backend so that it works in lib mode

This commit is contained in:
Quentin Fuxa
2025-07-01 20:34:47 +02:00
parent 62bf28949e
commit e1d4bf7e94
10 changed files with 752 additions and 2 deletions

View File

@@ -35,7 +35,7 @@ WhisperLiveKit consists of three main components:
- **Confidence Validation** Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
- **Buffering Preview** Displays unvalidated transcription segments (not compatible with SimulStreaming yet)
- **Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
- **SimulStreaming Backend** - Ultra-low latency transcription using state-of-the-art AlignAtt policy. The code is not directly included in the repo : To use, please copy [simul_whisper](https://github.com/ufal/SimulStreaming/tree/main/simul_whisper) content into `whisperlivekit/simul_whisper` . ⚠️ You must comply with the [Polyform license](https://github.com/ufal/SimulStreaming/blob/main/LICENCE.txt)
- **SimulStreaming Backend** - Ultra-low latency transcription using state-of-the-art AlignAtt policy. The code is not directly included in the repo : To use, please copy [whisper](https://github.com/ufal/SimulStreaming/tree/main/simul_whisper/whisper) folder into `whisperlivekit/simul_whisper/` . ⚠️ You must comply with the [Polyform license](https://github.com/ufal/SimulStreaming/blob/main/LICENCE.txt)
## Quick Start

View File

@@ -28,6 +28,7 @@ setup(
"torch",
"tqdm",
"tiktoken",
"numpy<2.0.0",
"triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
],
},

View File

View File

@@ -0,0 +1,17 @@
from .whisper.decoding import PyTorchInference
# extention of PyTorchInference for beam search
class BeamPyTorchInference(PyTorchInference):
def _kv_modules(self):
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
return key_modules + value_modules
def rearrange_kv_cache(self, source_indices):
if source_indices != list(range(len(source_indices))):
for module_cache_id in self._kv_modules():
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()
from torch import Tensor
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

View File

@@ -0,0 +1,29 @@
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
from dataclasses import dataclass, field
from typing import Literal
@dataclass
class SimulWhisperConfig:
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
model_path: str
language: str = field(default="zh")
nonspeech_prob: float = 1.0
audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5
task: Literal["transcribe","translate"] = "transcribe"
init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None)
@dataclass
class AlignAttConfig(SimulWhisperConfig):
'''Options specific to the AlignAtt policy.'''
eval_data_path: str = "tmp"
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
frame_threshold: int = 4
rewind_threshold: int = 200
audio_max_len: float = 30.0
cif_ckpt_path: str = ""
never_fire: bool = False

View File

@@ -0,0 +1,65 @@
import torch
# code for the end-of-word detection based on the CIF model proposed in Simul-Whisper
def load_cif(cfg, n_audio_state, device):
"""cfg: AlignAttConfig, n_audio_state: int, device: torch.device"""
cif_linear = torch.nn.Linear(n_audio_state, 1)
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
if cfg.never_fire:
never_fire = True
always_fire = False
else:
always_fire = True
never_fire = False
else:
always_fire = False
never_fire = cfg.never_fire
checkpoint = torch.load(cfg.cif_ckpt_path)
cif_linear.load_state_dict(checkpoint)
cif_linear.to(device)
return cif_linear, always_fire, never_fire
# from https://github.com/dqqcasia/mosst/blob/master/fairseq/models/speech_to_text/convtransformer_wav2vec_cif.py
def resize(alphas, target_lengths, threshold=0.999):
"""
alpha in thresh=1.0 | (0.0, +0.21)
target_lengths: if None, apply round and resize, else apply scaling
"""
# sum
_num = alphas.sum(-1)
num = target_lengths.float()
# scaling
_alphas = alphas * (num / _num)[:, None].repeat(1, alphas.size(1))
# rm attention value that exceeds threashold
count = 0
while len(torch.where(_alphas > threshold)[0]):
count += 1
if count > 10:
break
xs, ys = torch.where(_alphas > threshold)
for x, y in zip(xs, ys):
if _alphas[x][y] >= threshold:
mask = _alphas[x].ne(0).float()
mean = 0.5 * _alphas[x].sum() / mask.sum()
_alphas[x] = _alphas[x] * 0.5 + mean * mask
return _alphas, _num
def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
content_mel_len = chunked_encoder_feature.shape[1] # B, T, D
alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2) # B, T
alphas = torch.sigmoid(alphas)
decode_length = torch.round(alphas.sum(-1)).int()
alphas, _ = resize(alphas, decode_length)
alphas = alphas.squeeze(0) # (T, )
threshold = 0.999
integrate = torch.cumsum(alphas[:-1], dim=0) # ignore the peak value at the end of the content chunk
exceed_count = integrate[-1] // threshold
integrate = integrate - exceed_count*1.0 # minus 1 every time intergrate exceed the threshold
important_positions = (integrate >= 0).nonzero(as_tuple=True)[0]
if important_positions.numel() == 0:
return False
else:
return important_positions[0] >= content_mel_len-2

View File

@@ -0,0 +1,40 @@
class Tokens:
def __init__(self, tokens):
self.tokens = tokens
# def clone(self):
# return Tokens(self.tokens.clone())
def __str__(self):
return str(self.tokens.tolist())
def __repr__(self):
return self.__str__()
class BeamTokens(Tokens):
def __init__(self, tokens, beam_size):
self.tokens = tokens
self.beam_size = beam_size
def clone(self):
return BeamTokens(self.tokens.clone())
def __str__(self):
return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"
def __repr__(self):
return self.__str__()
class Logits(Tokens):
def __init__(self, logits):
super().__init__(logits)
# def clone(self):
# return Logits(self.tokens.clone(), self.beam_size)
def __str__(self):
# return "abc"
return f"Logits({self.tokens.shape})"
def __repr__(self):
return self.__str__()

View File

@@ -0,0 +1,525 @@
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
import os
import logging
import torch
import torch.nn.functional as F
from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter
from .whisper.decoding import SuppressBlank, GreedyDecoder, BeamSearchDecoder, SuppressTokens
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
import os
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer
import numpy as np
from .generation_progress import *
DEC_PAD = 50257
logger = logging.getLogger(__name__)
import sys
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
# - translation support
# - beam search
# - prompt -- static vs. non-static
# - context
class PaddedAlignAttWhisper:
def __init__(self, cfg: AlignAttConfig) -> None:
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
self.model = load_model(name=model_name, download_root=model_path)
logger.info(f"Model dimensions: {self.model.dims}")
decode_options = DecodingOptions(
language = cfg.language,
without_timestamps = True,
task=cfg.task
)
self.tokenizer = tokenizer.get_tokenizer(
multilingual=not model_name.endswith(".en"),
language=cfg.language,
num_languages=self.model.num_languages,
task=decode_options.task
)
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
self.cfg = cfg
# model to detect end-of-word boundary at the end of the segment
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
n_audio_state=self.model.dims.n_audio_state,
device=self.model.device)
# install hooks to access encoder-decoder attention
self.dec_attns = []
def layer_hook(module, net_input, net_output):
# net_output[1]: B*num_head*token_len*audio_len
t = F.softmax(net_output[1], dim=-1)
self.dec_attns.append(t.squeeze(0))
for b in self.model.decoder.blocks:
b.cross_attn.register_forward_hook(layer_hook)
self.kv_cache = {}
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
# save as-is, for the first token or cross attention
self.kv_cache[module.cache_id] = net_output
else:
x = self.kv_cache[module.cache_id]
self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
return self.kv_cache[module.cache_id]
for i,b in enumerate(self.model.decoder.blocks):
b.attn.key.register_forward_hook(kv_hook)
b.attn.value.register_forward_hook(kv_hook)
b.cross_attn.key.register_forward_hook(kv_hook)
b.cross_attn.value.register_forward_hook(kv_hook)
self.align_source = {}
self.num_align_heads = 0
for layer_rank, head_id in self.model.alignment_heads.indices().T:
layer_rank = layer_rank.item()
heads = self.align_source.get(layer_rank, [])
heads.append((self.num_align_heads, head_id.item()))
self.align_source[layer_rank] = heads
self.num_align_heads += 1
# init tokens (mandatory prompt)
self.initial_tokens = torch.tensor(
self.tokenizer.sot_sequence_including_notimestamps,
dtype=torch.long,
device=self.model.device).unsqueeze(0)
self.initial_token_length = self.initial_tokens.shape[1]
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
# tokens to be suppressed from decoding, to prevent hallucinations
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
# self.tokenizer.eot
self.tokenizer.no_timestamps, # added by DM
] + list(self.tokenizer.all_language_tokens) # added by DM
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {suppress_tokens}")
sup_tokens = SuppressTokens(suppress_tokens)
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
# blank tokens are suppresed for new segments near the line 334
# decoder type: greedy or beam
if cfg.decoder_type == "greedy":
logger.info("Using greedy decoder")
self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
self.decoder_type = "greedy"
elif cfg.decoder_type == "beam":
self.decoder_type = "beam"
self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
self.inference.kv_cache = self.kv_cache
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
# init state
self.segments = []
self.tokens = [self.initial_tokens]
self.last_attend_frame = -self.cfg.rewind_threshold
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
self.init_context()
def init_context(self):
kw = {'tokenizer': self.tokenizer,
'device': self.model.device,
'prefix_token_ids': [self.tokenizer.sot_prev]}
self.context = TokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.context.text += self.cfg.init_prompt
def trim_context(self):
logger.info("Trimming context")
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
logger.info(f"Context text: {self.context.as_text()}")
# logger.debug(f"Context tensor: {self.context.as_tensor()}")
l = sum(t.shape[1] for t in self.tokens) + c
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if self.cfg.static_init_prompt is None:
after = 0
else:
after = len(self.cfg.static_init_prompt)
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
logger.info(f"Context after trim: {self.context.text} (len: {l})")
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
if self.cfg.decoder_type == "greedy":
logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
else:
logger.debug(f"Logits shape: {tokens.shape}")
logit = self.inference.logits(tokens, audio_features)
return logit
def refresh_segment(self, complete=False):
logger.debug("Refreshing segment")
self.tokens = [self.initial_tokens]
self.last_attend_frame = -self.cfg.rewind_threshold
self.init_context()
logger.debug(f"Context: {self.context}")
if not complete and len(self.segments) > 2:
self.segments = self.segments[-2:]
else:
self.segments = []
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
if self.always_fire: return True
if self.never_fire: return False
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
def _current_tokens(self):
toks = self.tokens
# very first infer: duplicate start of seq to beam_size
if toks[0].shape[0] == 1:
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0)
if not self.context.is_empty():
context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
toks = [context_toks] + toks
# make it one tensor
if len(toks) > 1:
current_tokens = torch.cat(toks, dim=1)
else:
current_tokens = toks[0]
logger.debug("debug print current_tokens:")
self.debug_print_tokens(current_tokens)
return current_tokens
def debug_print_tokens(self, tokens):
for i in range(self.cfg.beam_size):
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
### audio buffer
def segments_len(self):
segments_len = sum(s.shape[0] for s in self.segments) / 16000
return segments_len
def _apply_minseglen(self):
segments_len = self.segments_len()
# wait for long enough audio to start
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False
return True
def insert_audio(self, segment=None):
if segment is not None:
self.segments.append(segment)
removed_len = 0
# len of audio is bigger than buffer_len. Going to remove the first segment
segments_len = self.segments_len()
while segments_len > self.cfg.audio_max_len:
removed_len = self.segments[0].shape[0] / 16000
segments_len -= removed_len
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
self.segments = self.segments[1:]
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
self.context.append_token_ids(self.tokens[1][0,:])
self.tokens = [self.initial_tokens] + self.tokens[2:]
return removed_len
### transcription / translation
@torch.no_grad()
def infer(self, is_last=False):
new_segment = True
if len(self.segments) == 0:
return []
if not self._apply_minseglen():
return []
# input_segments is concatenation of audio, it's one array
if len(self.segments) > 1:
input_segments = torch.cat(self.segments, dim=0)
else:
input_segments = self.segments[0]
self.trim_context()
current_tokens = self._current_tokens()
# mel + padding to 30s
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
device=self.model.device).unsqueeze(0)
# trim to 3000
mel = pad_or_trim(mel_padded, N_FRAMES)
# the len of actual audio
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
encoder_feature = self.model.encoder(mel)
sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
completed = False
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
####################### Decoding loop
logger.info("Decoding loop starts\n")
attn_of_alignment_heads = None
miost_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
generation_progress = []
generation = {
"starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
"token_len_before_decoding": token_len_before_decoding,
#"fire_detected": fire_detected,
"frames_len": content_mel_len,
"frames_threshold": 4 if is_last else self.cfg.frame_threshold,
# to be filled later
"logits_starting": None,
# to be filled later
"no_speech_prob": None,
"no_speech": False,
# to be filled in the loop
"progress": generation_progress,
}
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
generation_progress_loop = []
if new_segment:
tokens_for_logits = current_tokens
else:
# only need to use the last token except in the first forward pass
tokens_for_logits = current_tokens[:,-1:]
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
if new_segment:
generation["logits_starting"] = Logits(logits[:,:,:])
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
generation["no_speech_prob"] = no_speech_probs[0]
if no_speech_probs[0] > self.cfg.nonspeech_prob:
generation["no_speech"] = True
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # logits for the last token
generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
# supress blank tokens only at the beginning of the segment
if new_segment:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
new_segment = False
self.suppress_tokens(logits)
#generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
generation_progress_loop.append(("completed",completed))
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
self.debug_print_tokens(current_tokens)
# if self.decoder_type == "beam":
# logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
# logprobs = F.log_softmax(logits.float(), dim=-1)
# idx = 0
# logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
# logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
# if completed:
# self.debug_print_tokens(current_tokens)
# logger.debug("decode stopped because decoder completed")
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
for i, attn_mat in enumerate(self.dec_attns):
layer_rank = int(i % len(self.model.decoder.blocks))
align_heads_in_layer = self.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
a = attn_mat[head_id, :, :]
a = a.unsqueeze(0)
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
t = torch.cat(mat, dim=1)
tmp.append(t)
attn_of_alignment_heads = torch.stack(tmp, dim=1)
# logger.debug(str(attn_of_alignment_heads.shape) + " tttady")
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
# logger.debug(str(attn_of_alignment_heads.shape) + " po mean")
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
# logger.debug(str(attn_of_alignment_heads.shape) + " pak ")
# for each beam, the most attended frame is:
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
most_attended_frame = most_attended_frames[0].item()
generation_progress.append(dict(generation_progress_loop))
logger.debug("current tokens" + str(current_tokens.shape))
if completed:
# # stripping the last token, the eot
current_tokens = current_tokens[:, :-1]
break
# for some rare cases where the attention fails
if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
# TODO: check this
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
logger.debug("ommit rewinding from special tokens")
self.last_attend_frame = most_attended_frame
else:
logger.debug(
f"[rewind detected] current attention pos: {most_attended_frame}, "
f"last attention pos: {self.last_attend_frame}; omit this segment")
self.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
break
else:
self.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
# stripping the last token, the one that is attended too close to the end
current_tokens = current_tokens[:, :-1]
break
# debug print
for i in range(self.cfg.beam_size):
logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
most_attended_frames[i],
current_tokens[i, -1].item(),
self.tokenizer.decode([current_tokens[i, -1].item()])
))
# for k,v in generation.items():
# print(k,v,file=sys.stderr)
# for x in generation_progress:
# for y in x.items():
# print("\t\t",*y,file=sys.stderr)
# print("\t","----", file=sys.stderr)
# print("\t", "end of generation_progress_loop", file=sys.stderr)
# sys.exit(1)
####################### End of decoding loop
logger.info("End of decoding loop")
# if attn_of_alignment_heads is not None:
# seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
# # Lets' now consider only the top hypothesis in the beam search
# top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
# # debug print: how is the new token attended?
# new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
# logger.debug(f"New token attention shape: {new_token_attn.shape}")
# if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment
# logger.debug("no token generated")
# else: # it is, and the max attention is:
# new_token_max_attn, _ = new_token_attn.max(dim=-1)
# logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
# let's now operate only with the top beam hypothesis
tokens_to_split = current_tokens[0, token_len_before_decoding:]
if fire_detected or is_last:
new_hypothesis = tokens_to_split.flatten().tolist()
else:
# going to truncate the tokens after the last space
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
# text_to_split = self.tokenizer.decode(tokens_to_split)
# logger.debug(f"text_to_split: {text_to_split}")
# logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
# text_before_space = " ".join(text_to_split.split(" ")[:-1])
# logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
### new hypothesis
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
device=self.model.device,
)
self.tokens.append(new_tokens)
# TODO: test if this is redundant or not
# ret = ret[ret<DEC_PAD]
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
# cleaning cache
self.dec_attns = []
self.kv_cache = {}
if self.decoder_type == "beam":
self.inference.kv_cache = self.kv_cache
self.token_decoder.reset()
return new_hypothesis, generation

View File

@@ -0,0 +1,73 @@
import torch
import sys
class TokenBuffer:
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
self.text = text
self.prefix_token_ids = prefix_token_ids
self.tokenizer = tokenizer
self.device = device
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_tensor(self, device=None):
if device is None:
device = self.device
if device is None:
raise ValueError("Device is not set.")
tok_ids = self.as_token_ids()
return torch.tensor(tok_ids,
dtype=torch.long, device=device).unsqueeze(0)
def as_tensor_beam(self, beam, device=None):
t = self.as_tensor(device=device)
return t.repeat_interleave(beam, dim=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return TokenBuffer(*a,**kw)
@staticmethod
def from_text(text, *a, **kw):
return TokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
'''
num: how many words to trim from the beginning
after: how many characters to skip (length of the static prompt)
'''
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
# print(words, file=sys.stderr)
# print(wids, file=sys.stderr)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
self.text += self.tokenizer.decode(token_ids)
def as_split_word_tokens(self):
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text)
return tokenizer.split_to_word_tokens(ids)

View File

@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
# simulStreaming imports - we check if the files are here
try:
import torch
from simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.config import AlignAttConfig
SIMULSTREAMING_AVAILABLE = True
except ImportError:
logger.warning("SimulStreaming dependencies not available for online processor.")