"""Determine alignment heads for a variants, such as distilled model""" from __future__ import annotations import argparse import base64 import gzip import io import pathlib import sys import math from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch from datasets import Audio as DatasetAudio, load_dataset import soundfile as sf import matplotlib.pyplot as plt REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] WHISPER_ROOT = REPO_ROOT / "whisper" sys.path.insert(0, str(REPO_ROOT)) sys.path.insert(0, str(WHISPER_ROOT)) from whisper import load_model from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim from whisper.tokenizer import get_tokenizer AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor] def load_dataset_clips(name, config, split, limit): ds = load_dataset(name, config, split=split) ds = ds.cast_column("audio", DatasetAudio(decode=False)) clips = [] for idx, row in enumerate(ds): if limit is not None and idx >= limit: break audio_field = row["audio"] transcript = row["text"] waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32") if waveform_np.ndim > 1: waveform_np = waveform_np.mean(axis=1) waveform = waveform_np transcript = str(transcript) clips.append((waveform, transcript)) return clips def load_clips(args): return load_dataset_clips( args.dataset, args.dataset_config, args.dataset_split, args.dataset_num_samples, ) def _waveform_from_source(source: AudioInput) -> torch.Tensor: waveform = torch.from_numpy(source.astype(np.float32, copy=False)) return waveform def _parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--model", type=str, default="pytorch_model.bin", ) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Torch device to run on", ) parser.add_argument( "--dataset", type=str, default="librispeech_asr" ) parser.add_argument( "--dataset-config", type=str, default="clean" ) parser.add_argument( "--dataset-split", type=str, default="validation[:1%]", ) parser.add_argument( "--dataset-num-samples", type=int, default=16, ) parser.add_argument( "--threshold", type=float, default=1.5, help="Z score threshold for a head to be selected", ) parser.add_argument( "--votes", type=float, default=0.75, help="percentage of clips that must vote for a head", ) parser.add_argument( "--output", type=str, default="alignment_heads.b85", ) parser.add_argument( "--visualize-top-k", type=int, default=32, ) return parser.parse_args() def collect_heads( model, tokenizer, clips: Sequence[Tuple[AudioInput, str]], threshold: float, ) -> Tuple[torch.Tensor, torch.Tensor]: device = model.device votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device) strengths = torch.zeros_like(votes) for audio_source, transcript in clips: waveform = pad_or_trim(_waveform_from_source(audio_source)) mel = log_mel_spectrogram(waveform, device=device) tokens = torch.tensor( [ *tokenizer.sot_sequence, tokenizer.no_timestamps, *tokenizer.encode(transcript), tokenizer.eot, ], device=device, ) qks = [None] * model.dims.n_text_layer hooks = [ block.cross_attn.register_forward_hook( lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0]) ) for i, block in enumerate(model.decoder.blocks) ] with torch.no_grad(): model(mel.unsqueeze(0), tokens.unsqueeze(0)) for hook in hooks: hook.remove() for layer_idx, tensor in enumerate(qks): if tensor is None: continue tensor = tensor[:, :, : mel.shape[-1] // 2] tensor = 
tensor.softmax(dim=-1) peak = tensor.max(dim=-1).values # [heads, tokens] strengths[layer_idx] += peak.mean(dim=-1) zscore = (peak - peak.mean(dim=-1, keepdim=True)) / ( peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6 ) mask = (zscore > 3).any(dim=-1) votes[layer_idx] += mask.float() votes /= len(clips) strengths /= len(clips) return votes, strengths def _select_heads_for_visualization(selection, strengths, top_k): selected = torch.nonzero(selection, as_tuple=False) if selected.numel() == 0: return [] entries = [ (int(layer.item()), int(head.item()), float(strengths[layer, head].item())) for layer, head in selected ] entries.sort(key=lambda item: item[2], reverse=True) return entries[:top_k] def _extract_heatmaps( model, tokenizer, clip: Tuple[AudioInput, str], heads: Sequence[Tuple[int, int, float]], ) -> dict: if not heads: return {} target_map = {} for layer, head, _ in heads: target_map.setdefault(layer, set()).add(head) waveform = pad_or_trim(_waveform_from_source(clip[0])) mel = log_mel_spectrogram(waveform, device=model.device) transcript = clip[1] tokens = torch.tensor( [ *tokenizer.sot_sequence, tokenizer.no_timestamps, *tokenizer.encode(transcript), tokenizer.eot, ], device=model.device, ) QKs = [None] * model.dims.n_text_layer hooks = [ block.cross_attn.register_forward_hook( lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0]) ) for i, block in enumerate(model.decoder.blocks) ] with torch.no_grad(): model(mel.unsqueeze(0), tokens.unsqueeze(0)) for hook in hooks: hook.remove() heatmaps = {} for layer_idx, tensor in enumerate(QKs): if tensor is None or layer_idx not in target_map: continue tensor = tensor[:, :, : mel.shape[-1] // 2] tensor = tensor.softmax(dim=-1).cpu() for head_idx in target_map[layer_idx]: heatmaps[(layer_idx, head_idx)] = tensor[head_idx] return heatmaps def _plot_heatmaps( heads, heatmaps, output_path): cols = min(3, len(heads)) rows = math.ceil(len(heads) / cols) fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False) for idx, (layer, head, score) in enumerate(heads): ax = axes[idx // cols][idx % cols] mat = heatmaps.get((layer, head)) if mat is None: ax.axis("off") continue im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower") ax.set_title(f"L{layer} H{head} ยท score {score:.2f}") ax.set_xlabel("time") ax.set_ylabel("tokens") for j in range(len(heads), rows * cols): axes[j // cols][j % cols].axis("off") fig.tight_layout() fig.savefig(output_path, dpi=200) plt.close(fig) def _dump_mask(mask: torch.Tensor, output_path: str): payload = mask.numpy().astype(np.bool_) blob = base64.b85encode(gzip.compress(payload.tobytes())) with open(output_path, "wb") as f: f.write(blob) def main(): args = _parse_args() model = load_model(args.model, device=args.device) model.eval() tokenizer = get_tokenizer(multilingual=model.is_multilingual) clips = load_clips(args) votes, strengths = collect_heads(model, tokenizer, clips, args.threshold) # selection = votes > 0.5 selection = strengths > 0.05 _dump_mask(selection.cpu(), args.output) viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k) heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads) _plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png") if __name__ == "__main__": main()
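
# Usage sketch (an assumption, not part of the original script): the mask
# written by _dump_mask is a base85-encoded gzip of a flat boolean
# [n_text_layer, n_text_head] array, which matches the format decoded by
# openai-whisper's Whisper.set_alignment_heads(). The script filename below
# is hypothetical.
#
#   $ python find_alignment_heads.py --model pytorch_model.bin --votes 0.75
#
#   from whisper import load_model
#   model = load_model("pytorch_model.bin")
#   with open("alignment_heads.b85", "rb") as f:
#       model.set_alignment_heads(f.read())  # enables word-level timestamps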