WhisperLiveKit/scripts/detect_alignment_heads_qwen3.py

#!/usr/bin/env python3
"""
Detect alignment heads in Qwen3-ASR for SimulStreaming-style inference.

Qwen3-ASR is a decoder-only multimodal model: audio is encoded by an audio
encoder and the resulting embeddings are injected into the text sequence
(replacing <|audio_pad|> placeholder tokens). The text decoder then attends
over the full sequence -- both audio-derived tokens and text tokens -- via
causal self-attention. There is **no** cross-attention.

For AlignAtt-style streaming, we need to find which (layer, head) pairs in
the text decoder's self-attention best track the monotonic alignment between
generated text tokens and their corresponding audio positions.

Algorithm
---------
For each audio sample with a known transcript:

1. Run Qwen3-ASR generation with eager attention, capturing per-head
   attention weights via forward hooks (see run_detection)
2. Use the ForcedAligner to get ground-truth word->timestamp alignments
3. Convert timestamps to audio token positions in the input sequence
4. For each generated text token, check whether the argmax of each
   attention head (over the audio-token region) points to the correct
   audio position (as determined by the forced aligner)
5. Accumulate scores per (layer, head)

The heads whose attention argmax matches the ground-truth alignment most
often are the "alignment heads" usable for SimulStreaming.

Reference: Adapted from scripts/determine_alignment_heads.py (Whisper) and
iwslt26-sst/SimulMT_tests/heads/detect_translation_heads_qwen3.py
"""
import argparse
import io
import json
import logging
import re
import time
from difflib import SequenceMatcher
from typing import List, Optional, Tuple
import numpy as np
import soundfile as sf
import torch
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# ── Compatibility patches for qwen_asr 0.0.6 + transformers >= 5.3 ────
def _apply_transformers_compat_patches():
    """Apply all necessary patches to make qwen_asr work with transformers >= 5.3."""
    # 1. check_model_inputs was removed
    try:
        import transformers.utils.generic as _g

        if not hasattr(_g, "check_model_inputs"):
            def check_model_inputs(*args, **kwargs):
                def decorator(fn):
                    return fn
                return decorator

            _g.check_model_inputs = check_model_inputs
    except ImportError:
        pass

    # 2. 'default' rope type was removed from ROPE_INIT_FUNCTIONS
    try:
        from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

        if "default" not in ROPE_INIT_FUNCTIONS:
            def _compute_default_rope_parameters(config=None, device=None, seq_len=None, **kwargs):
                if hasattr(config, "head_dim"):
                    head_dim = config.head_dim
                else:
                    head_dim = config.hidden_size // config.num_attention_heads
                partial = getattr(config, "partial_rotary_factor", 1.0)
                dim = int(head_dim * partial)
                base = config.rope_theta
                inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
                return inv_freq, 1.0

            ROPE_INIT_FUNCTIONS["default"] = _compute_default_rope_parameters
    except ImportError:
        pass

    # 3. pad_token_id missing on thinker config
    try:
        from qwen_asr.core.transformers_backend.configuration_qwen3_asr import (
            Qwen3ASRThinkerConfig,
        )

        if not hasattr(Qwen3ASRThinkerConfig, "pad_token_id"):
            Qwen3ASRThinkerConfig.pad_token_id = None
    except ImportError:
        pass

    # 4. fix_mistral_regex is now handled internally by transformers 5.3;
    #    qwen_asr passes it explicitly, causing a duplicate-kwarg error.
    try:
        from transformers.models.auto import processing_auto

        _orig_ap_from_pretrained = processing_auto.AutoProcessor.from_pretrained.__func__

        @classmethod
        def _patched_ap_from_pretrained(cls, *args, **kwargs):
            kwargs.pop("fix_mistral_regex", None)
            return _orig_ap_from_pretrained(cls, *args, **kwargs)

        processing_auto.AutoProcessor.from_pretrained = _patched_ap_from_pretrained
    except Exception:
        pass

    # 5. _finalize_model_loading calls initialize_weights which expects
    #    compute_default_rope_parameters on RotaryEmbedding modules.
    try:
        from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
            Qwen3ASRThinkerTextRotaryEmbedding,
        )

        if not hasattr(Qwen3ASRThinkerTextRotaryEmbedding, "compute_default_rope_parameters"):
            @staticmethod
            def _compute_default_rope_parameters(config=None, device=None, seq_len=None, **kwargs):
                if hasattr(config, "head_dim"):
                    head_dim = config.head_dim
                else:
                    head_dim = config.hidden_size // config.num_attention_heads
                partial = getattr(config, "partial_rotary_factor", 1.0)
                dim = int(head_dim * partial)
                base = config.rope_theta
                inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
                return inv_freq, 1.0

            Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters = _compute_default_rope_parameters
    except ImportError:
        pass


_apply_transformers_compat_patches()

# ── Constants ────────────────────────────────────────────────────────
SAMPLE_RATE = 16000
TS_THRESHOLD = 0.1  # Minimum alignment score (TS) for a head to qualify as an alignment head
MIN_TEXT_SIMILARITY = 0.3 # Skip clips where generated text is too different from ground truth


def text_similarity(generated: str, reference: str) -> float:
    """Compute text similarity between generated and reference transcriptions.

    Normalizes both strings (lowercase, remove punctuation, collapse whitespace)
    then returns SequenceMatcher ratio.
    """
    def normalize(s):
        s = s.lower()
        s = re.sub(r'[^\w\s]', '', s)
        return re.sub(r'\s+', ' ', s).strip()

    gen_norm = normalize(generated)
    ref_norm = normalize(reference)
    if not gen_norm or not ref_norm:
        return 0.0
    return SequenceMatcher(None, gen_norm, ref_norm).ratio()
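
# e.g. text_similarity("Hello, world!", "hello world") == 1.0 after
# normalization, while a hallucinated "!!!" normalizes to "" and scores 0.0.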


def load_dataset_clips(name, config, split, limit):
    """Load audio clips from a HuggingFace dataset."""
    from datasets import Audio as DatasetAudio
    from datasets import load_dataset

    ds = load_dataset(name, config, split=split)
    ds = ds.cast_column("audio", DatasetAudio(decode=False))
    clips = []
    for idx, row in enumerate(ds):
        if limit is not None and idx >= limit:
            break
        audio_field = row["audio"]
        transcript = row["text"]
        waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
        if waveform_np.ndim > 1:
            waveform_np = waveform_np.mean(axis=1)
        clips.append((waveform_np, str(transcript)))
    return clips


def get_device():
    """Select the best available device."""
    if torch.backends.mps.is_available():
        logger.info("Using MPS (Apple Silicon GPU)")
        return torch.device("mps")
    elif torch.cuda.is_available():
        logger.info("Using CUDA (%s)", torch.cuda.get_device_name())
        return torch.device("cuda")
    else:
        logger.info("Using CPU (will be slow)")
        return torch.device("cpu")


def load_qwen3_asr(model_id: str, device: torch.device, dtype: torch.dtype):
    """Load Qwen3-ASR model, processor, and forced aligner."""
    from qwen_asr.core.transformers_backend import (
        Qwen3ASRConfig,
        Qwen3ASRForConditionalGeneration,
        Qwen3ASRProcessor,
    )
    from qwen_asr.inference.qwen3_forced_aligner import Qwen3ForcedAligner
    from transformers import AutoConfig, AutoModel, AutoProcessor

    AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
    AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
    AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)

    logger.info("Loading model: %s (dtype=%s, device=%s)", model_id, dtype, device)
    model = AutoModel.from_pretrained(
        model_id,
        torch_dtype=dtype,
        attn_implementation="eager",
        device_map=str(device),
    )
    model.eval()

    # Force eager attention on all sub-modules (attn_implementation="eager" doesn't
    # propagate through nested model configs in qwen_asr's custom architecture)
    for name, module in model.named_modules():
        if hasattr(module, "config") and hasattr(module.config, "_attn_implementation"):
            module.config._attn_implementation = "eager"
            module.config._attn_implementation_internal = "eager"

    try:
        processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True)
    except TypeError:
        processor = AutoProcessor.from_pretrained(model_id)

    logger.info("Loading forced aligner: Qwen/Qwen3-ForcedAligner-0.6B")
    forced_aligner = Qwen3ForcedAligner.from_pretrained(
        "Qwen/Qwen3-ForcedAligner-0.6B",
        dtype=dtype,
        device_map=str(device),
    )
    return model, processor, forced_aligner


def find_audio_token_range(input_ids: torch.Tensor, audio_token_id: int) -> Tuple[int, int]:
    """Find the start and end positions of audio tokens in the input sequence."""
    mask = (input_ids == audio_token_id)
    positions = mask.nonzero(as_tuple=True)[0]
    if len(positions) == 0:
        return 0, 0
    return positions[0].item(), positions[-1].item() + 1
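
# e.g. input_ids = [t0, PAD, PAD, PAD, t1] with audio_token_id == PAD
# yields (1, 4): a half-open [start, end) range over the audio tokens.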


def timestamp_to_audio_token_position(
    timestamp_sec: float,
    audio_duration_sec: float,
    audio_token_start: int,
    audio_token_end: int,
) -> int:
    """Convert a timestamp in seconds to the corresponding audio token position.

    Audio tokens span [audio_token_start, audio_token_end) in the input sequence.
    We linearly interpolate within that range based on the timestamp fraction.
    """
    n_audio_tokens = audio_token_end - audio_token_start
    if n_audio_tokens <= 0 or audio_duration_sec <= 0:
        return audio_token_start
    fraction = min(timestamp_sec / audio_duration_sec, 1.0)
    pos = audio_token_start + int(fraction * (n_audio_tokens - 1))
    return max(audio_token_start, min(pos, audio_token_end - 1))


def run_detection(
    model,
    processor,
    forced_aligner,
    clips: List[Tuple[np.ndarray, str]],
    language: Optional[str],
    device: torch.device,
) -> Tuple[np.ndarray, int]:
    """Run alignment head detection on a set of audio clips.

    Uses PyTorch forward hooks on each self_attn module to capture attention
    weights that the decoder layer discards (``hidden_states, _ = self.self_attn(...)``).
    With eager attention, ``self_attn`` always returns ``(attn_output, attn_weights)``,
    so the hook can read the weights from the return value.

    Returns:
        g: array of shape (total_heads,) with alignment hit counts
        m: total number of alignment checks performed
    """
    thinker = model.thinker
    text_config = thinker.config.text_config
    num_layers = text_config.num_hidden_layers
    num_heads = text_config.num_attention_heads
    total_heads = num_layers * num_heads
    audio_token_id = thinker.config.audio_token_id
    logger.info(
        "Text decoder: %d layers x %d heads = %d total heads",
        num_layers, num_heads, total_heads,
    )
    logger.info(
        "KV heads: %d (GQA ratio: %d)",
        text_config.num_key_value_heads,
        num_heads // text_config.num_key_value_heads,
    )

    # Build prompt helper (same as Qwen3ASRModel._build_text_prompt)
    from qwen_asr.inference.utils import normalize_language_name

    def build_messages(audio_payload):
        return [
            {"role": "system", "content": ""},
            {"role": "user", "content": [{"type": "audio", "audio": audio_payload}]},
        ]

    def build_text_prompt(force_language=None):
        msgs = build_messages("")
        base = processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
        if force_language:
            base = base + f"language {force_language}<asr_text>"
        return base

    force_lang = None
    if language:
        force_lang = normalize_language_name(language)

    # Stop token IDs
    eos_ids = {151645, 151643}  # <|im_end|>, <|endoftext|>
    if processor.tokenizer.eos_token_id is not None:
        eos_ids.add(processor.tokenizer.eos_token_id)

    # Decoder layers: model.thinker.model.layers[i].self_attn
    decoder_layers = thinker.model.layers

    g = np.zeros(total_heads, dtype=np.int64)
    m = 0
    t0 = time.time()
    for clip_idx, (waveform, transcript) in enumerate(clips):
        if not transcript.strip():
            continue
        audio_duration = len(waveform) / SAMPLE_RATE

        # 1. Get forced alignment timestamps
        try:
            align_results = forced_aligner.align(
                audio=[(waveform, SAMPLE_RATE)],
                text=[transcript],
                language=[force_lang or "English"],
            )
            align_result = align_results[0]
        except Exception as e:
            logger.warning("Forced alignment failed for clip %d: %s", clip_idx, e)
            continue
        if not align_result.items:
            continue

        # Build word -> (start_time, end_time) mapping
        word_timestamps = []
        for item in align_result.items:
            word_timestamps.append((item.text, item.start_time, item.end_time))

        # 2. Prepare inputs
        text_prompt = build_text_prompt(force_language=force_lang)
        inputs = processor(
            text=[text_prompt],
            audio=[waveform],
            return_tensors="pt",
            padding=True,
        )
        inputs = inputs.to(model.device).to(model.dtype)
        prompt_len = inputs.input_ids.shape[1]

        # Find audio token range
        audio_start, audio_end = find_audio_token_range(
            inputs.input_ids[0], audio_token_id,
        )
        n_audio_tokens = audio_end - audio_start
        if n_audio_tokens == 0:
            logger.warning("No audio tokens found in clip %d", clip_idx)
            continue

        # 3. Register forward hooks on self_attn to capture attention weights.
        #    The decoder layer discards them: hidden_states, _ = self.self_attn(...)
        #    but eager_attention_forward always computes and returns attn_weights.
        #    We capture just the argmax over the audio region (memory-efficient).
        #    captured_argmax[layer_idx] = list of (num_heads,) tensors, one per decode step.
        captured_argmax = {i: [] for i in range(num_layers)}

        def _make_hook(store, a_start, a_end):
            def hook_fn(module, args, output):
                # output = (attn_output, attn_weights)
                attn_weights = output[1]
                if attn_weights is None:
                    return
                # attn_weights shape: (batch, num_heads, q_len, kv_len)
                # Only capture decode steps (q_len == 1), skip prefill
                if attn_weights.shape[2] != 1:
                    return
                kv_len = attn_weights.shape[-1]
                if a_end > kv_len:
                    return
                # Attention from the new token over the audio region
                audio_attn = attn_weights[0, :, 0, a_start:a_end]  # (num_heads, n_audio)
                store.append(audio_attn.argmax(dim=-1).cpu())  # (num_heads,)
            return hook_fn
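
        # Note: with GQA, HF's eager attention path repeats KV heads before the
        # softmax, so attn_weights already have num_attention_heads query heads
        # and per-head indexing in the hook is valid.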
        hooks = []
        for layer_idx in range(num_layers):
            h = decoder_layers[layer_idx].self_attn.register_forward_hook(
                _make_hook(captured_argmax[layer_idx], audio_start, audio_end)
            )
            hooks.append(h)

        # 4. Run generation (hooks are removed in the finally block either way)
        try:
            with torch.inference_mode():
                outputs = thinker.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                )
        except Exception as e:
            logger.warning("Generation failed for clip %d: %s", clip_idx, e)
            continue
        finally:
            for h in hooks:
                h.remove()
        # outputs is a (batch, seq_len) tensor
        all_generated = outputs[0, prompt_len:]
        num_gen = len(all_generated)
        for i, tid in enumerate(all_generated):
            if tid.item() in eos_ids:
                num_gen = i
                break
        generated_ids = all_generated[:num_gen]
        if num_gen == 0:
            del outputs, captured_argmax
            continue
        generated_text = processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Filter out hallucinated clips (e.g. "!!!" patterns)
        sim = text_similarity(generated_text, transcript)
        if sim < MIN_TEXT_SIMILARITY:
            logger.info(
                "[%d/%d] SKIP (sim=%.2f) | %s...",
                clip_idx + 1, len(clips), sim, generated_text[:60],
            )
            del outputs, captured_argmax
            continue

        # Verify hooks captured data
        n_captured = len(captured_argmax[0])
        if n_captured == 0:
            logger.warning(
                "No attention weights captured for clip %d (hooks may not have fired)", clip_idx
            )
            del outputs, captured_argmax
            continue
        # 5. Map generated tokens to word timestamps
        gen_token_strings = [
            processor.tokenizer.decode([tid.item()]) for tid in generated_ids
        ]

        # Map each generated token index -> forced-aligner word index
        accumulated_text = ""
        word_idx = 0
        token_to_word = {}
        for tok_idx, tok_str in enumerate(gen_token_strings):
            accumulated_text += tok_str
            # Advance word index when accumulated text covers the current word
            while (
                word_idx < len(word_timestamps)
                and len(accumulated_text.strip()) >= sum(
                    len(w[0]) + 1 for w in word_timestamps[:word_idx + 1]
                )
            ):
                word_idx += 1
            actual_word_idx = min(word_idx, len(word_timestamps) - 1)
            token_to_word[tok_idx] = actual_word_idx
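
        # e.g. for aligned words ["the", "cat"] the cumulative thresholds are
        # 4 and 8 chars (len(word) + 1 each): a token is assigned to the first
        # word whose threshold the accumulated decoded text has not yet reached.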

        # 6. Score each head using captured argmax data
        for gen_step in range(num_gen):
            word_idx = token_to_word.get(gen_step, None)
            if word_idx is None or word_idx >= len(word_timestamps):
                continue
            _, word_start, word_end = word_timestamps[word_idx]
            word_mid = (word_start + word_end) / 2.0

            # Expected audio token position for this word
            expected_pos = timestamp_to_audio_token_position(
                word_mid, audio_duration, audio_start, audio_end,
            )

            # Tolerance: +/- a few audio tokens (proportional to word duration)
            word_dur_tokens = max(1, int(
                (word_end - word_start) / audio_duration * n_audio_tokens / 2
            ))
            tolerance = max(3, word_dur_tokens)
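            # e.g. a 0.5 s word in a 10 s clip with 100 audio tokens gives
            # word_dur_tokens = int(0.5 / 10 * 100 / 2) = 2, so tolerance = 3.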
            m += 1
            for layer_idx in range(num_layers):
                if gen_step >= len(captured_argmax[layer_idx]):
                    continue
                argmaxes = captured_argmax[layer_idx][gen_step].numpy()  # (num_heads,)
                for head_idx in range(num_heads):
                    attended_pos = argmaxes[head_idx]  # relative to audio_start
                    attended_abs = audio_start + attended_pos
                    if abs(attended_abs - expected_pos) <= tolerance:
                        g[layer_idx * num_heads + head_idx] += 1

        del outputs, captured_argmax
        if device.type == "mps":
            torch.mps.empty_cache()
        elif device.type == "cuda":
            torch.cuda.empty_cache()

        elapsed = time.time() - t0
        avg = elapsed / (clip_idx + 1)
        eta = avg * (len(clips) - clip_idx - 1)
        logger.info(
            "[%d/%d] m=%d | %s... | %.1fs/clip | ETA: %.0fs",
            clip_idx + 1, len(clips), m,
            generated_text[:60], avg, eta,
        )

    return g, m


def main():
    parser = argparse.ArgumentParser(
        description="Detect alignment heads in Qwen3-ASR for SimulStreaming"
    )
    parser.add_argument(
        "--model", type=str, default="Qwen/Qwen3-ASR-1.7B",
        help="Qwen3-ASR model name or path",
    )
    parser.add_argument(
        "--dataset", type=str, default="librispeech_asr",
        help="HuggingFace dataset name",
    )
    parser.add_argument(
        "--dataset-config", type=str, default="clean",
        help="Dataset config/subset",
    )
    parser.add_argument(
        "--dataset-split", type=str, default="validation",
        help="Dataset split",
    )
    parser.add_argument(
        "-n", "--num-samples", type=int, default=50,
        help="Number of audio samples to process",
    )
    parser.add_argument(
        "--language", type=str, default="English",
        help="Language for forced alignment",
    )
    parser.add_argument(
        "--dtype", type=str, default="bf16",
        choices=["float32", "bf16", "float16"],
        help="Model dtype",
    )
    parser.add_argument(
        "-o", "--output", type=str, default="alignment_heads_qwen3_asr.json",
        help="Output JSON file",
    )
    parser.add_argument(
        "--heatmap", type=str, default="alignment_heads_qwen3_asr.png",
        help="Output heatmap image",
    )
    parser.add_argument(
        "--threshold", type=float, default=TS_THRESHOLD,
        help="Minimum alignment score threshold",
    )
    args = parser.parse_args()

    device = get_device()
    dtype_map = {
        "float32": torch.float32,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
    }
    dtype = dtype_map[args.dtype]

    # Load model
    model, processor, forced_aligner = load_qwen3_asr(args.model, device, dtype)

    # Load data
    logger.info("Loading dataset: %s/%s [%s]", args.dataset, args.dataset_config, args.dataset_split)
    clips = load_dataset_clips(
        args.dataset, args.dataset_config, args.dataset_split, args.num_samples,
    )
    logger.info("Loaded %d clips", len(clips))

    # Run detection
    g, m = run_detection(model, processor, forced_aligner, clips, args.language, device)

    # Compute alignment scores
    thinker = model.thinker
    text_config = thinker.config.text_config
    num_layers = text_config.num_hidden_layers
    num_heads = text_config.num_attention_heads
    ts = g / max(m, 1)
    ts_matrix = ts.reshape(num_layers, num_heads)

    # Identify alignment heads
    tah = []
    for l in range(num_layers):
        for h in range(num_heads):
            score = ts_matrix[l, h]
            if score > args.threshold:
                tah.append({"layer": l, "head": h, "ts": round(float(score), 4)})
    tah.sort(key=lambda x: x["ts"], reverse=True)

    # Print results
    print(f"\n{'=' * 60}")
    print(f"ALIGNMENT HEADS (TS > {args.threshold}): {len(tah)} / {num_layers * num_heads}")
    print(f"{'=' * 60}")
    for entry in tah:
        bar = "#" * int(entry["ts"] * 50)
        print(f" L{entry['layer']:2d} H{entry['head']:2d} : TS={entry['ts']:.4f} {bar}")

    n_active = sum(1 for s in ts if s > args.threshold)
    n_low = sum(1 for s in ts if 0 < s <= args.threshold)
    n_zero = sum(1 for s in ts if s == 0)
    total_heads = num_layers * num_heads
    print("\nDistribution:")
    print(f" TS > {args.threshold} (alignment heads): {n_active} ({100 * n_active / total_heads:.1f}%)")
    print(f" 0 < TS <= {args.threshold} (low activity): {n_low} ({100 * n_low / total_heads:.1f}%)")
    print(f" TS = 0 (inactive): {n_zero} ({100 * n_zero / total_heads:.1f}%)")
    print(f"\nTotal alignable tokens checked: m={m}")

    # Save JSON
    output = {
        "model": args.model,
        "language": args.language,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "num_kv_heads": text_config.num_key_value_heads,
        "num_samples": len(clips),
        "total_alignable_tokens": int(m),
        "ts_threshold": args.threshold,
        "ts_matrix": ts_matrix.tolist(),
        "alignment_heads": tah,
        # WhisperLiveKit-compatible format: list of [layer, head] pairs
        "alignment_heads_compact": [[e["layer"], e["head"]] for e in tah],
    }
    with open(args.output, "w") as f:
        json.dump(output, f, indent=2)
    logger.info("Results saved to %s", args.output)

    # Generate heatmap
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(
            figsize=(max(10, num_heads * 0.6), max(8, num_layers * 0.35)),
        )
        im = ax.imshow(
            ts_matrix,
            aspect="auto",
            cmap="RdYlBu_r",
            vmin=0,
            vmax=max(0.4, ts_matrix.max()),
            interpolation="nearest",
        )
        ax.set_xlabel("Head ID", fontsize=12)
        ax.set_ylabel("Layer", fontsize=12)
        ax.set_title(
            f"Alignment Scores - {args.model}\n"
            f"{len(tah)} alignment heads (TS > {args.threshold}), n={len(clips)}",
            fontsize=13,
        )
        ax.set_xticks(range(num_heads))
        ax.set_yticks(range(num_layers))
        plt.colorbar(im, ax=ax, label="Alignment Score", shrink=0.8)
        for entry in tah:
            ax.add_patch(plt.Rectangle(
                (entry["head"] - 0.5, entry["layer"] - 0.5),
                1, 1, fill=False, edgecolor="red", linewidth=1.5,
            ))
        plt.tight_layout()
        plt.savefig(args.heatmap, dpi=150)
        logger.info("Heatmap saved to %s", args.heatmap)
    except Exception as e:
        logger.warning("Could not generate heatmap: %s", e)


if __name__ == "__main__":
    main()