"""End-of-word (boundary) detection utilities based on the CIF model proposed in Simul-Whisper."""
import torch
# code for the end-of-word detection based on the CIF model proposed in Simul-Whisper
def load_cif(cfg, n_audio_state, device):
    """Build the CIF boundary-detection linear layer and its firing-mode flags.

    Args:
        cfg: AlignAttConfig-like object; reads ``cfg.cif_ckpt_path`` (str or
            None) and ``cfg.never_fire`` (bool).
        n_audio_state: width of the encoder feature dimension.
        device: torch.device the linear layer is moved to.

    Returns:
        Tuple ``(cif_linear, always_fire, never_fire)`` where ``cif_linear``
        is an ``nn.Linear(n_audio_state, 1)``.  Without trained weights the
        detector cannot be trusted, so the policy degenerates to either
        always firing or never firing.
    """
    cif_linear = torch.nn.Linear(n_audio_state, 1)
    # Truthiness covers both None and the empty string (the original
    # `is None or not ...` double-check was redundant).
    if not cfg.cif_ckpt_path:
        # No checkpoint: fall back to a degenerate firing policy.
        never_fire = bool(cfg.never_fire)
        always_fire = not cfg.never_fire
    else:
        always_fire = False
        never_fire = cfg.never_fire
        # map_location avoids a hard failure when a GPU-saved checkpoint is
        # loaded on a CPU-only host; the layer is moved to `device` below.
        checkpoint = torch.load(cfg.cif_ckpt_path, map_location="cpu")
        cif_linear.load_state_dict(checkpoint)
    cif_linear.to(device)
    return cif_linear, always_fire, never_fire
# from https://github.com/dqqcasia/mosst/blob/master/fairseq/models/speech_to_text/convtransformer_wav2vec_cif.py
def resize(alphas, target_lengths, threshold=0.999):
    """Rescale CIF attention weights so each row sums to its target length.

    Every row is multiplied by ``target_length / current_sum``.  Entries that
    still exceed ``threshold`` afterwards are repeatedly smoothed — halved
    and blended with half the row mean over non-zero positions — for at most
    10 passes, so the loop always terminates.

    Returns:
        Tuple of the rescaled tensor and the original per-row sums.
    """
    original_sums = alphas.sum(-1)
    desired = target_lengths.float()
    # Per-row scaling factor, broadcast across the time dimension.
    scale = (desired / original_sums)[:, None].repeat(1, alphas.size(1))
    scaled = alphas * scale
    # Smooth out entries above the threshold (bounded number of passes).
    passes = 0
    while len(torch.where(scaled > threshold)[0]):
        passes += 1
        if passes > 10:
            break
        rows, cols = torch.where(scaled > threshold)
        for r, c in zip(rows, cols):
            if scaled[r][c] >= threshold:
                nonzero_mask = scaled[r].ne(0).float()
                half_mean = 0.5 * scaled[r].sum() / nonzero_mask.sum()
                scaled[r] = scaled[r] * 0.5 + half_mean * nonzero_mask
    return scaled, original_sums
def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
    """Decide whether the CIF weights "fire" (predict a word boundary) at the
    end of the current content chunk.

    Args:
        chunked_encoder_feature: encoder output of shape (B, T, D).  Only
            batch size 1 is supported — the batch dim is squeezed below.
        cif_linear: linear layer (D -> 1) producing per-frame CIF logits.

    Returns:
        True iff the last accumulated firing position falls within the final
        two frames of the chunk (i.e. a word boundary aligns with chunk end).
    """
    content_mel_len = chunked_encoder_feature.shape[1]  # B, T, D
    alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2)  # B, T
    alphas = torch.sigmoid(alphas)
    decode_length = torch.round(alphas.sum(-1)).int()
    alphas, _ = resize(alphas, decode_length)
    alphas = alphas.squeeze(0)  # (T, )
    threshold = 0.999
    # Ignore the peak value at the end of the content chunk.
    integrate = torch.cumsum(alphas[:-1], dim=0)
    if integrate.numel() == 0:
        # Single-frame chunk: alphas[:-1] is empty, nothing can have fired
        # (the original code raised IndexError on integrate[-1] here).
        return False
    # Subtract 1.0 for each time the integral already crossed the threshold,
    # keeping only the remainder accumulated since the last firing.
    exceed_count = integrate[-1] // threshold
    integrate = integrate - exceed_count * 1.0
    important_positions = (integrate >= 0).nonzero(as_tuple=True)[0]
    if important_positions.numel() == 0:
        return False
    else:
        return important_positions[0] >= content_mel_len - 2