diff --git a/README.md b/README.md
index 2fbd1da..451c82f 100644
--- a/README.md
+++ b/README.md
@@ -171,7 +171,7 @@ async def websocket_endpoint(websocket: WebSocket):
 | SimulStreaming backend options | Description | Default |
 |-----------|-------------|---------|
 | `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
-| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
+| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them | `None` |
 | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
 | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
 | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
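+
+For example, to extract alignment heads from a custom checkpoint (paths are illustrative; by default the script collects statistics over a small LibriSpeech slice):
+
+```bash
+python scripts/determine_alignment_heads.py --model ./my-distilled-model.pt --output alignment_heads.b85
+```
+
+The resulting base85 blob can then be supplied via `--custom-alignment-heads`.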
diff --git a/scripts/alignment_heads.png b/scripts/alignment_heads.png
new file mode 100644
index 0000000..f276fc7
Binary files /dev/null and b/scripts/alignment_heads.png differ
diff --git a/scripts/determine_alignment_heads.py b/scripts/determine_alignment_heads.py
new file mode 100644
index 0000000..cd5344f
--- /dev/null
+++ b/scripts/determine_alignment_heads.py
@@ -0,0 +1,292 @@
+"""Determine alignment heads for model variants, such as distilled models."""
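+#
+# Method, as implemented below: the reference transcript is teacher-forced
+# through the decoder while forward hooks capture each block's cross-attention
+# weights. Heads that track the audio show a sharp per-token peak after a
+# softmax over encoder frames; per-clip z-scores of those peaks become votes,
+# and the selected (layer, head) mask is written out gzip-compressed and
+# base85-encoded, the same format whisper uses for its built-in alignment
+# heads.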
+from __future__ import annotations
+
+import argparse
+import base64
+import gzip
+import io
+import math
+import pathlib
+import sys
+from typing import Sequence, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import soundfile as sf
+import torch
+from datasets import Audio as DatasetAudio, load_dataset
+
+REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
+WHISPER_ROOT = REPO_ROOT / "whisper"
+
+# Make the local whisper checkout importable before importing from it.
+sys.path.insert(0, str(REPO_ROOT))
+sys.path.insert(0, str(WHISPER_ROOT))
+
+from whisper import load_model
+from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
+from whisper.tokenizer import get_tokenizer
+
+AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
+
+
+def load_dataset_clips(name, config, split, limit):
+    """Load (waveform, transcript) pairs from a HuggingFace dataset."""
+    ds = load_dataset(name, config, split=split)
+    ds = ds.cast_column("audio", DatasetAudio(decode=False))
+    clips = []
+    for idx, row in enumerate(ds):
+        if limit is not None and idx >= limit:
+            break
+        audio_field = row["audio"]
+        waveform, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
+        if waveform.ndim > 1:  # downmix multi-channel audio to mono
+            waveform = waveform.mean(axis=1)
+        clips.append((waveform, str(row["text"])))
+    return clips
+
+
+def load_clips(args):
+    return load_dataset_clips(
+        args.dataset,
+        args.dataset_config,
+        args.dataset_split,
+        args.dataset_num_samples,
+    )
+
+
+def _waveform_from_source(source: AudioInput) -> torch.Tensor:
+    if isinstance(source, torch.Tensor):
+        return source.float()
+    if isinstance(source, np.ndarray):
+        return torch.from_numpy(source.astype(np.float32, copy=False))
+    return torch.from_numpy(load_audio(str(source)))
+
+
+def _parse_args():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="pytorch_model.bin",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Torch device to run on",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="librispeech_asr",
+    )
+    parser.add_argument(
+        "--dataset-config",
+        type=str,
+        default="clean",
+    )
+    parser.add_argument(
+        "--dataset-split",
+        type=str,
+        default="validation[:1%]",
+    )
+    parser.add_argument(
+        "--dataset-num-samples",
+        type=int,
+        default=16,
+    )
+    parser.add_argument(
+        "--threshold",
+        type=float,
+        default=1.5,
+        help="Per-clip z-score threshold for a head to receive a vote",
+    )
+    parser.add_argument(
+        "--votes",
+        type=float,
+        default=0.75,
+        help="Fraction of clips that must vote for a head",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="alignment_heads.b85",
+    )
+    parser.add_argument(
+        "--visualize-top-k",
+        type=int,
+        default=32,
+    )
+    return parser.parse_args()
+
+
+def collect_heads(
+    model,
+    tokenizer,
+    clips: Sequence[Tuple[AudioInput, str]],
+    threshold: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    device = model.device
+    votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
+    strengths = torch.zeros_like(votes)
+
+    for audio_source, transcript in clips:
+        waveform = pad_or_trim(_waveform_from_source(audio_source))
+        mel = log_mel_spectrogram(waveform, n_mels=model.dims.n_mels, device=device)
+
+        # Teacher-force the reference transcript so each decoder position
+        # corresponds to a known token.
+        tokens = torch.tensor(
+            [
+                *tokenizer.sot_sequence,
+                tokenizer.no_timestamps,
+                *tokenizer.encode(transcript),
+                tokenizer.eot,
+            ],
+            device=device,
+        )
+
+        # Capture the cross-attention QK matrix of every decoder block.
+        qks = [None] * model.dims.n_text_layer
+        hooks = [
+            block.cross_attn.register_forward_hook(
+                lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
+            )
+            for i, block in enumerate(model.decoder.blocks)
+        ]
+
+        with torch.no_grad():
+            model(mel.unsqueeze(0), tokens.unsqueeze(0))
+
+        for hook in hooks:
+            hook.remove()
+
+        for layer_idx, tensor in enumerate(qks):
+            if tensor is None:
+                continue
+            # The encoder downsamples time 2x; keep the matching frames, then
+            # measure how sharply each head focuses per token.
+            tensor = tensor[:, :, : mel.shape[-1] // 2]
+            tensor = tensor.softmax(dim=-1)
+            peak = tensor.max(dim=-1).values  # [heads, tokens]
+            strengths[layer_idx] += peak.mean(dim=-1)
+            zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
+                peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
+            )
+            mask = (zscore > threshold).any(dim=-1)
+            votes[layer_idx] += mask.float()
+
+    votes /= len(clips)
+    strengths /= len(clips)
+    return votes, strengths
+
+
+def _select_heads_for_visualization(selection, strengths, top_k):
+    selected = torch.nonzero(selection, as_tuple=False)
+    if selected.numel() == 0:
+        return []
+
+    entries = [
+        (int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
+        for layer, head in selected
+    ]
+    entries.sort(key=lambda item: item[2], reverse=True)
+    return entries[:top_k]
+
+
+def _extract_heatmaps(
+    model,
+    tokenizer,
+    clip: Tuple[AudioInput, str],
+    heads: Sequence[Tuple[int, int, float]],
+) -> dict:
+    if not heads:
+        return {}
+
+    target_map = {}
+    for layer, head, _ in heads:
+        target_map.setdefault(layer, set()).add(head)
+
+    waveform = pad_or_trim(_waveform_from_source(clip[0]))
+    mel = log_mel_spectrogram(waveform, n_mels=model.dims.n_mels, device=model.device)
+    transcript = clip[1]
+    tokens = torch.tensor(
+        [
+            *tokenizer.sot_sequence,
+            tokenizer.no_timestamps,
+            *tokenizer.encode(transcript),
+            tokenizer.eot,
+        ],
+        device=model.device,
+    )
+
+    qks = [None] * model.dims.n_text_layer
+    hooks = [
+        block.cross_attn.register_forward_hook(
+            lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
+        )
+        for i, block in enumerate(model.decoder.blocks)
+    ]
+
+    with torch.no_grad():
+        model(mel.unsqueeze(0), tokens.unsqueeze(0))
+
+    for hook in hooks:
+        hook.remove()
+
+    heatmaps = {}
+    for layer_idx, tensor in enumerate(qks):
+        if tensor is None or layer_idx not in target_map:
+            continue
+        tensor = tensor[:, :, : mel.shape[-1] // 2]
+        tensor = tensor.softmax(dim=-1).cpu()
+        for head_idx in target_map[layer_idx]:
+            heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
+
+    return heatmaps
+
+
+def _plot_heatmaps(heads, heatmaps, output_path):
+    if not heads:
+        return
+    cols = min(3, len(heads))
+    rows = math.ceil(len(heads) / cols)
+    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
+
+    for idx, (layer, head, score) in enumerate(heads):
+        ax = axes[idx // cols][idx % cols]
+        mat = heatmaps.get((layer, head))
+        if mat is None:
+            ax.axis("off")
+            continue
+        ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
+        ax.set_title(f"L{layer} H{head} · score {score:.2f}")
+        ax.set_xlabel("time")
+        ax.set_ylabel("tokens")
+
+    # Hide any unused subplot cells.
+    for j in range(len(heads), rows * cols):
+        axes[j // cols][j % cols].axis("off")
+
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=200)
+    plt.close(fig)
+
+
+def _dump_mask(mask: torch.Tensor, output_path: str):
+    # gzip-compressed boolean array, base85-encoded: the same encoding
+    # whisper uses for its built-in alignment heads.
+    payload = mask.numpy().astype(np.bool_)
+    blob = base64.b85encode(gzip.compress(payload.tobytes()))
+    with open(output_path, "wb") as f:
+        f.write(blob)
+
+
+def main():
+    args = _parse_args()
+    model = load_model(args.model, device=args.device)
+    model.eval()
+    tokenizer = get_tokenizer(multilingual=model.is_multilingual)
+    clips = load_clips(args)
+
+    votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
+    # Select by mean attention peak strength; vote-based selection
+    # (votes >= args.votes) is a stricter alternative.
+    selection = strengths > 0.05
+    _dump_mask(selection.cpu(), args.output)
+
+    viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
+    heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
+    _plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sync_extension.py b/scripts/sync_extension.py
similarity index 92%
rename from sync_extension.py
rename to scripts/sync_extension.py
index 0ccae60..ea6f8d6 100644
--- a/sync_extension.py
+++ b/scripts/sync_extension.py
@@ -1,9 +1,10 @@
+"""Copy core files from web directory to Chrome extension directory."""
+
 import shutil
 import os
 from pathlib import Path
 
 
 def sync_extension_files():
-    """Copy core files from web directory to Chrome extension directory."""
     web_dir = Path("whisperlivekit/web")
     extension_dir = Path("chrome-extension")
diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index 3e3b63f..26bce5f 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -195,7 +195,6 @@ class SimulStreamingASR():
             'large-v3': './large-v3.pt',
             'large': './large-v3.pt'
         }
-        self.pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
         self.model_name = self.model_size
 
         is_multilingual = not self.model_name.endswith(".en")