mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
import torch
|
|
|
|
# code for the end-of-word detection based on the CIF model proposed in Simul-Whisper
|
|
|
|
def load_cif(cfg, n_audio_state, device):
    """Build the CIF boundary-detection head and its firing-mode flags.

    Args:
        cfg: AlignAttConfig; reads cfg.cif_ckpt_path and cfg.never_fire.
        n_audio_state: int, encoder hidden size (input dim of the linear head).
        device: torch.device to place the linear layer on.

    Returns:
        (cif_linear, always_fire, never_fire): a Linear(n_audio_state, 1)
        scoring head plus two booleans telling the caller to unconditionally
        fire / never fire instead of consulting the CIF head.
    """
    cif_linear = torch.nn.Linear(n_audio_state, 1)
    if not cfg.cif_ckpt_path:  # None or empty string: no trained weights
        # Without trained CIF weights boundary detection is impossible; fall
        # back to always firing unless the config explicitly forbids firing.
        never_fire = bool(cfg.never_fire)
        always_fire = not never_fire
    else:
        always_fire = False
        never_fire = cfg.never_fire
        # map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts.
        checkpoint = torch.load(cfg.cif_ckpt_path, map_location=device)
        cif_linear.load_state_dict(checkpoint)
    cif_linear.to(device)
    return cif_linear, always_fire, never_fire
|
|
|
|
|
|
# from https://github.com/dqqcasia/mosst/blob/master/fairseq/models/speech_to_text/convtransformer_wav2vec_cif.py
|
|
def resize(alphas, target_lengths, threshold=0.999):
    """Rescale CIF firing weights so each row sums to its target length.

    Every row is multiplied by (target sum / actual sum); entries pushed
    above ``threshold`` by the scaling are then smoothed by blending the
    row 50/50 with its per-row mean over nonzero entries, for at most 10
    passes.

    Args:
        alphas: (B, T) firing weights.
        target_lengths: (B,) desired per-row sums.
        threshold: per-entry upper bound triggering smoothing.

    Returns:
        (scaled, raw_totals): the rescaled weights and the per-row sums of
        the input before scaling.
    """
    # Per-row totals before scaling; returned for the caller's use.
    raw_totals = alphas.sum(-1)
    targets = target_lengths.float()

    # Scale each row so its sum matches the target length.
    row_scale = (targets / raw_totals)[:, None].repeat(1, alphas.size(1))
    scaled = alphas * row_scale

    # Redistribute mass away from entries that exceed the threshold,
    # capped at 10 smoothing passes.
    for _ in range(10):
        rows, cols = torch.where(scaled > threshold)
        if rows.numel() == 0:
            break
        for r, c in zip(rows, cols):
            # A row may already have been smoothed earlier in this pass,
            # so re-check the entry before touching the row again.
            if scaled[r][c] >= threshold:
                nonzero_mask = scaled[r].ne(0).float()
                row_mean = 0.5 * scaled[r].sum() / nonzero_mask.sum()
                scaled[r] = scaled[r] * 0.5 + row_mean * nonzero_mask

    return scaled, raw_totals
|
|
|
|
def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
    """Decide whether the CIF accumulator fires at the end of this chunk.

    Args:
        chunked_encoder_feature: encoder output of shape (B, T, D);
            assumed B == 1 — the batch dim is squeezed away below.
        cif_linear: Linear(D, 1) scoring head produced by load_cif.

    Returns:
        bool: True when the last CIF firing position falls within the final
        two frames of the chunk, i.e. a word boundary coincides with the
        chunk end; False otherwise (including degenerate chunks of T <= 1).
    """
    content_mel_len = chunked_encoder_feature.shape[1]  # B, T, D
    alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2)  # B, T
    alphas = torch.sigmoid(alphas)
    # Expected token count = rounded sum of the firing weights.
    decode_length = torch.round(alphas.sum(-1)).int()
    alphas, _ = resize(alphas, decode_length)
    alphas = alphas.squeeze(0)  # (T, )
    threshold = 0.999
    integrate = torch.cumsum(alphas[:-1], dim=0)  # ignore the peak value at the end of the content chunk
    if integrate.numel() == 0:
        # Chunk of length <= 1: nothing to integrate, so no boundary.
        # (The original code raised an index error on integrate[-1] here.)
        return False
    # Subtract 1 for every time the accumulator already crossed the
    # threshold, leaving only the final (incomplete) accumulation positive.
    exceed_count = integrate[-1] // threshold
    integrate = integrate - exceed_count * 1.0  # minus 1 every time integrate exceeds the threshold
    # Frames where the remaining accumulation is non-negative = firing points.
    important_positions = (integrate >= 0).nonzero(as_tuple=True)[0]
    if important_positions.numel() == 0:
        return False
    # Fire only if the first such position lies within the last two frames.
    return bool(important_positions[0] >= content_mel_len - 2)