mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Add a script to detect alignement heads, usefull for distilled whisper
This commit is contained in:
292
scripts/determine_alignment_heads.py
Normal file
292
scripts/determine_alignment_heads.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""Determine alignment heads for a variants, such as distilled model"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import gzip
|
||||
import io
|
||||
import pathlib
|
||||
import sys
|
||||
import math
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Audio as DatasetAudio, load_dataset
|
||||
import soundfile as sf
|
||||
import matplotlib.pyplot as plt
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
sys.path.insert(0, str(WHISPER_ROOT))
|
||||
|
||||
from whisper import load_model
|
||||
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
||||
|
||||
|
||||
def load_dataset_clips(name, config, split, limit):
|
||||
ds = load_dataset(name, config, split=split)
|
||||
ds = ds.cast_column("audio", DatasetAudio(decode=False))
|
||||
clips = []
|
||||
for idx, row in enumerate(ds):
|
||||
if limit is not None and idx >= limit:
|
||||
break
|
||||
audio_field = row["audio"]
|
||||
transcript = row["text"]
|
||||
|
||||
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
|
||||
if waveform_np.ndim > 1:
|
||||
waveform_np = waveform_np.mean(axis=1)
|
||||
waveform = waveform_np
|
||||
transcript = str(transcript)
|
||||
|
||||
clips.append((waveform, transcript))
|
||||
return clips
|
||||
|
||||
|
||||
def load_clips(args):
|
||||
return load_dataset_clips(
|
||||
args.dataset,
|
||||
args.dataset_config,
|
||||
args.dataset_split,
|
||||
args.dataset_num_samples,
|
||||
)
|
||||
|
||||
|
||||
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
|
||||
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
|
||||
return waveform
|
||||
|
||||
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="pytorch_model.bin",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Torch device to run on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="librispeech_asr"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-config",
|
||||
type=str,
|
||||
default="clean"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
default="validation[:1%]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-num-samples",
|
||||
type=int,
|
||||
default=16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="Z score threshold for a head to be selected",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--votes",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="percentage of clips that must vote for a head",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="alignment_heads.b85",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--visualize-top-k",
|
||||
type=int,
|
||||
default=32,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def collect_heads(
|
||||
model,
|
||||
tokenizer,
|
||||
clips: Sequence[Tuple[AudioInput, str]],
|
||||
threshold: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
device = model.device
|
||||
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
|
||||
strengths = torch.zeros_like(votes)
|
||||
|
||||
for audio_source, transcript in clips:
|
||||
waveform = pad_or_trim(_waveform_from_source(audio_source))
|
||||
mel = log_mel_spectrogram(waveform, device=device)
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
|
||||
qks = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
for layer_idx, tensor in enumerate(qks):
|
||||
if tensor is None:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1)
|
||||
peak = tensor.max(dim=-1).values # [heads, tokens]
|
||||
strengths[layer_idx] += peak.mean(dim=-1)
|
||||
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
|
||||
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
|
||||
)
|
||||
mask = (zscore > 3).any(dim=-1)
|
||||
votes[layer_idx] += mask.float()
|
||||
|
||||
votes /= len(clips)
|
||||
strengths /= len(clips)
|
||||
return votes, strengths
|
||||
|
||||
|
||||
def _select_heads_for_visualization(selection, strengths, top_k):
|
||||
selected = torch.nonzero(selection, as_tuple=False)
|
||||
if selected.numel() == 0:
|
||||
return []
|
||||
|
||||
entries = [
|
||||
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
|
||||
for layer, head in selected
|
||||
]
|
||||
entries.sort(key=lambda item: item[2], reverse=True)
|
||||
return entries[:top_k]
|
||||
|
||||
def _extract_heatmaps(
|
||||
model,
|
||||
tokenizer,
|
||||
clip: Tuple[AudioInput, str],
|
||||
heads: Sequence[Tuple[int, int, float]],
|
||||
) -> dict:
|
||||
if not heads:
|
||||
return {}
|
||||
|
||||
target_map = {}
|
||||
for layer, head, _ in heads:
|
||||
target_map.setdefault(layer, set()).add(head)
|
||||
|
||||
waveform = pad_or_trim(_waveform_from_source(clip[0]))
|
||||
mel = log_mel_spectrogram(waveform, device=model.device)
|
||||
transcript = clip[1]
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
heatmaps = {}
|
||||
for layer_idx, tensor in enumerate(QKs):
|
||||
if tensor is None or layer_idx not in target_map:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1).cpu()
|
||||
for head_idx in target_map[layer_idx]:
|
||||
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
|
||||
|
||||
return heatmaps
|
||||
|
||||
|
||||
def _plot_heatmaps(
|
||||
heads, heatmaps, output_path):
|
||||
cols = min(3, len(heads))
|
||||
rows = math.ceil(len(heads) / cols)
|
||||
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
|
||||
|
||||
for idx, (layer, head, score) in enumerate(heads):
|
||||
ax = axes[idx // cols][idx % cols]
|
||||
mat = heatmaps.get((layer, head))
|
||||
if mat is None:
|
||||
ax.axis("off")
|
||||
continue
|
||||
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
|
||||
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
|
||||
ax.set_xlabel("time")
|
||||
ax.set_ylabel("tokens")
|
||||
|
||||
for j in range(len(heads), rows * cols):
|
||||
axes[j // cols][j % cols].axis("off")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def _dump_mask(mask: torch.Tensor, output_path: str):
|
||||
payload = mask.numpy().astype(np.bool_)
|
||||
blob = base64.b85encode(gzip.compress(payload.tobytes()))
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(blob)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
model = load_model(args.model, device=args.device)
|
||||
model.eval()
|
||||
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
|
||||
clips = load_clips(args)
|
||||
|
||||
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
|
||||
# selection = votes > 0.5
|
||||
selection = strengths > 0.05
|
||||
_dump_mask(selection.cpu(), args.output)
|
||||
|
||||
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
|
||||
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
|
||||
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user