qwen3 simul+kv: optimized streaming with kv cache reuse

This commit is contained in:
Quentin Fuxa
2026-03-15 18:30:00 +01:00
parent ed503be140
commit b69eaf82be
4 changed files with 4261 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,116 @@
"""
Bridge between WhisperLiveKit STT and IWSLT26 MT pipeline.
Converts streaming ASRToken output from SimulStreaming into the JSONL
format expected by the AlignAtt MT agent (iwslt26-sst).
Output format (one JSON per line):
{"text": "word or phrase", "emission_time": 1.234, "is_final": false, "speech_time": 1.0}
Where:
- text: the emitted word/phrase
- emission_time: wall-clock time when the word was emitted (for compute-aware eval)
- speech_time: timestamp in the audio (for compute-unaware eval)
- is_final: whether this is the last word of a segment/silence boundary
"""
import json
import time
from typing import List, TextIO
from whisperlivekit.timed_objects import ASRToken
class CascadeBridge:
"""Converts ASRToken stream to JSONL for the MT agent."""
def __init__(self, output_file: TextIO = None):
self.output_file = output_file
self.start_time = time.time()
self.entries: List[dict] = []
def emit_tokens(self, tokens: List[ASRToken], is_final: bool = False):
"""Emit a batch of tokens from the STT."""
wall_clock = time.time() - self.start_time
for i, token in enumerate(tokens):
entry = {
"text": token.text.strip(),
"emission_time": round(wall_clock, 3),
"speech_time": round(token.start, 3),
"is_final": is_final and (i == len(tokens) - 1),
}
self.entries.append(entry)
if self.output_file:
self.output_file.write(json.dumps(entry) + "\n")
self.output_file.flush()
def get_entries(self) -> List[dict]:
return self.entries
def get_text(self) -> str:
"""Get the full transcribed text."""
return " ".join(e["text"] for e in self.entries if e["text"])
def save(self, path: str):
"""Save all entries to a JSONL file."""
with open(path, "w") as f:
for entry in self.entries:
f.write(json.dumps(entry) + "\n")
def run_stt_to_jsonl(
audio_path: str,
output_path: str,
model_id: str = "Qwen/Qwen3-ASR-0.6B",
alignment_heads_path: str = None,
border_fraction: float = 0.20,
language: str = "en",
chunk_sec: float = 1.0,
):
"""Run STT on an audio file and save JSONL output for the MT agent.
This is the main entry point for the cascade: audio file → JSONL.
"""
import wave
import numpy as np
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR, Qwen3SimulKVOnlineProcessor
# Load audio
with wave.open(audio_path, 'r') as wf:
audio = np.frombuffer(
wf.readframes(wf.getnframes()), dtype=np.int16
).astype(np.float32) / 32768.0
# Initialize STT
asr = Qwen3SimulKVASR(
model_dir=model_id,
lan=language,
alignment_heads_path=alignment_heads_path,
border_fraction=border_fraction,
)
proc = Qwen3SimulKVOnlineProcessor(asr)
bridge = CascadeBridge()
# Stream audio in chunks
chunk_samples = int(chunk_sec * 16000)
offset = 0
stream_time = 0.0
while offset < len(audio):
chunk = audio[offset:offset + chunk_samples]
stream_time += len(chunk) / 16000
proc.insert_audio_chunk(chunk, stream_time)
words, _ = proc.process_iter(is_last=False)
if words:
bridge.emit_tokens(words, is_final=False)
offset += chunk_samples
# Final flush
final_words, _ = proc.finish()
if final_words:
bridge.emit_tokens(final_words, is_final=True)
# Save
bridge.save(output_path)
return bridge

View File

@@ -126,6 +126,15 @@ class TranscriptionEngine:
self.tokenizer = None
self.asr = Qwen3MLXASR(**transcription_common_params)
logger.info("Using Qwen3 MLX native backend")
elif config.backend == "qwen3-simul-kv":
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR
self.tokenizer = None
self.asr = Qwen3SimulKVASR(
**transcription_common_params,
alignment_heads_path=config.custom_alignment_heads,
border_fraction=getattr(config, 'border_fraction', 0.25),
)
logger.info("Using Qwen3-ASR backend with SimulStreaming+KV policy")
elif config.backend == "qwen3-simul":
from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR
self.tokenizer = None
@@ -235,6 +244,9 @@ def online_factory(args, asr, language=None):
if backend == "vllm-realtime":
from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor
return VLLMRealtimeOnlineProcessor(asr)
if backend == "qwen3-simul-kv":
from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVOnlineProcessor
return Qwen3SimulKVOnlineProcessor(asr)
if backend == "qwen3-mlx":
from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor
return Qwen3MLXOnlineProcessor(asr)

View File

@@ -0,0 +1,787 @@
"""
Qwen3-ASR SimulStreaming with KV cache reuse.
This is an optimized version of qwen3_simul.py that reuses the KV cache
across inference calls, avoiding redundant prefill of prompt + old audio.
Architecture:
1. First call: full prefill (prompt + audio tokens), greedy decode with
alignment-head stopping, save KV cache + generated tokens
2. Subsequent calls: invalidate KV for old audio suffix, prefill only
new audio tokens, continue decoding from saved state
3. Audio encoder caching: reuse embeddings for stable attention windows
This gives ~3-5x speedup over the original generate()-based approach.
"""
import json
import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
from transformers import DynamicCache
from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
logger = logging.getLogger(__name__)
SAMPLE_RATE = 16000
@dataclass
class Qwen3SimulKVConfig:
"""Configuration for Qwen3 SimulStreaming with KV cache."""
model_id: str = "Qwen/Qwen3-ASR-1.7B"
alignment_heads_path: Optional[str] = None
language: str = "auto"
border_fraction: float = 0.20
rewind_fraction: float = 0.12
audio_min_len: float = 0.5
audio_max_len: float = 20.0
max_context_tokens: int = 30
init_prompt: Optional[str] = None
max_alignment_heads: int = 10
@dataclass
class _AudioEmbedCache:
"""Cache for audio encoder outputs."""
encoded_samples: int = 0
embeddings: Optional[torch.Tensor] = None
encoded_mel_frames: int = 0
stable_tokens: int = 0
def reset(self):
self.encoded_samples = 0
self.embeddings = None
self.encoded_mel_frames = 0
self.stable_tokens = 0
@dataclass
class Qwen3SimulKVState:
"""Per-session mutable state with KV cache."""
# Audio
audio_buffer: np.ndarray = field(
default_factory=lambda: np.array([], dtype=np.float32)
)
cumulative_time_offset: float = 0.0
global_time_offset: float = 0.0
speaker: int = -1
# KV cache state
kv_cache: Optional[DynamicCache] = None
kv_seq_len: int = 0 # sequence length when KV was saved
prompt_token_count: int = 0 # tokens before audio (system prompt etc)
audio_token_count: int = 0 # audio tokens in the cached KV
generated_token_ids: List[int] = field(default_factory=list)
# Alignment tracking
last_attend_frame: int = -15
committed_text: str = ""
committed_word_count: int = 0
committed_token_ids: List[int] = field(default_factory=list)
# Tracking
first_timestamp: Optional[float] = None
detected_language: Optional[str] = None
last_infer_samples: int = 0
# Audio embedding cache
audio_cache: _AudioEmbedCache = field(default_factory=_AudioEmbedCache)
def reset_kv(self):
"""Reset KV cache (e.g., when audio is trimmed from front)."""
self.kv_cache = None
self.kv_seq_len = 0
self.prompt_token_count = 0
self.audio_token_count = 0
self.generated_token_ids = []
class Qwen3SimulKVASR:
"""
Shared backend for Qwen3-ASR SimulStreaming with KV cache reuse.
"""
sep = ""
def __init__(
self,
model_size: str = None,
model_dir: str = None,
lan: str = "auto",
alignment_heads_path: Optional[str] = None,
border_fraction: float = 0.15,
min_chunk_size: float = 0.1,
warmup_file: Optional[str] = None,
model_cache_dir: Optional[str] = None,
model_path: Optional[str] = None,
lora_path: Optional[str] = None,
direct_english_translation: bool = False,
**kwargs,
):
self.transcribe_kargs = {}
self.original_language = None if lan == "auto" else lan
self.warmup_file = warmup_file
self.cfg = Qwen3SimulKVConfig(
language=lan,
alignment_heads_path=alignment_heads_path,
border_fraction=border_fraction,
)
self._load_model(model_size, model_dir, model_cache_dir, model_path)
self.alignment_heads = self._load_alignment_heads(alignment_heads_path)
# Pre-compute heads by layer for efficient hook installation
self.heads_by_layer = {}
for layer_idx, head_idx in self.alignment_heads:
self.heads_by_layer.setdefault(layer_idx, []).append(head_idx)
if warmup_file:
from whisperlivekit.warmup import load_file
audio = load_file(warmup_file)
if audio is not None:
self._warmup(audio)
def _load_model(self, model_size, model_dir, model_cache_dir, model_path):
from whisperlivekit.qwen3_asr import QWEN3_MODEL_MAPPING, _patch_transformers_compat
_patch_transformers_compat()
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
if model_dir:
model_id = model_dir
elif model_path:
model_id = model_path
elif model_size:
model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
else:
model_id = "Qwen/Qwen3-ASR-1.7B"
if torch.cuda.is_available():
dtype, device = torch.bfloat16, "cuda:0"
else:
dtype, device = torch.float32, "cpu"
logger.info("Loading Qwen3-ASR for SimulStreaming+KV: %s", model_id)
self.model = AutoModel.from_pretrained(model_id, dtype=dtype, device_map=device)
self.model.eval()
self.processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True)
thinker = self.model.thinker
text_config = thinker.config.text_config
self.num_layers = text_config.num_hidden_layers
self.num_heads = text_config.num_attention_heads
self.num_kv_heads = text_config.num_key_value_heads
self.audio_token_id = thinker.config.audio_token_id
self.device = next(self.model.parameters()).device
self.dtype = next(self.model.parameters()).dtype
self.asr_text_token_id = self.processor.tokenizer.convert_tokens_to_ids("<asr_text>")
# EOS tokens
self.eos_ids = {151645, 151643}
if self.processor.tokenizer.eos_token_id is not None:
self.eos_ids.add(self.processor.tokenizer.eos_token_id)
logger.info(
"Qwen3-ASR loaded: %d layers x %d heads, device=%s",
self.num_layers, self.num_heads, self.device,
)
def _load_alignment_heads(self, path):
max_heads = self.cfg.max_alignment_heads
if path and Path(path).exists():
with open(path) as f:
data = json.load(f)
all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
heads = all_heads[:max_heads]
logger.info("Loaded top %d alignment heads from %s", len(heads), path)
return heads
default_heads = []
start_layer = self.num_layers * 3 // 4
for layer in range(start_layer, self.num_layers):
for head in range(self.num_heads):
default_heads.append((layer, head))
logger.warning("No alignment heads file. Using %d default heads.", len(default_heads))
return default_heads[:max_heads]
def _warmup(self, audio):
try:
audio = audio[:SAMPLE_RATE * 2]
msgs = [{"role": "system", "content": ""}, {"role": "user", "content": [{"type": "audio", "audio": ""}]}]
text_prompt = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
inputs = self.processor(text=[text_prompt], audio=[audio], return_tensors="pt", padding=True)
inputs = inputs.to(self.device).to(self.dtype)
with torch.inference_mode():
self.model.thinker.generate(**inputs, max_new_tokens=5, do_sample=False)
logger.info("Warmup complete")
except Exception as e:
logger.warning("Warmup failed: %s", e)
def transcribe(self, audio):
pass
class Qwen3SimulKVOnlineProcessor:
"""
Per-session online processor with KV cache reuse.
Key optimization: instead of calling generate() each time (which does
full prefill), we maintain a DynamicCache and do incremental prefill
+ manual greedy decoding with alignment head hooks.
"""
SAMPLING_RATE = 16000
MIN_DURATION_REAL_SILENCE = 5
def __init__(self, asr: Qwen3SimulKVASR, logfile=sys.stderr):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer: List[ASRToken] = []
self.state = Qwen3SimulKVState()
self._build_prompt_template()
def _build_prompt_template(self):
from whisperlivekit.qwen3_asr import WHISPER_TO_QWEN3_LANGUAGE
msgs = [
{"role": "system", "content": ""},
{"role": "user", "content": [{"type": "audio", "audio": ""}]},
]
self._base_prompt = self.asr.processor.apply_chat_template(
msgs, add_generation_prompt=True, tokenize=False,
)
lan = self.asr.cfg.language
if lan and lan != "auto":
lang_name = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan)
self._base_prompt += f"language {lang_name}<asr_text>"
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self.state.audio_buffer = np.append(self.state.audio_buffer, audio)
max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE)
if len(self.state.audio_buffer) > max_samples:
trim = len(self.state.audio_buffer) - max_samples
self.state.audio_buffer = self.state.audio_buffer[trim:]
self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)
self.state.audio_cache.reset()
self.state.reset_kv() # Must invalidate KV when audio is trimmed
def start_silence(self) -> Tuple[List[ASRToken], float]:
all_tokens = []
for _ in range(5):
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end
def end_silence(self, silence_duration: float, offset: float):
self.end += silence_duration
long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
if not long_silence:
gap_len = int(self.SAMPLING_RATE * silence_duration)
if gap_len > 0:
self.state.audio_buffer = np.append(
self.state.audio_buffer, np.zeros(gap_len, dtype=np.float32),
)
else:
self.state = Qwen3SimulKVState()
self.state.global_time_offset = silence_duration + offset
def new_speaker(self, change_speaker: ChangeSpeaker):
self.process_iter(is_last=True)
self.state = Qwen3SimulKVState()
self.state.speaker = change_speaker.speaker
self.state.global_time_offset = change_speaker.start
def get_buffer(self) -> Transcript:
return Transcript.from_tokens(tokens=self.buffer, sep='')
def _encode_audio(self) -> Tuple[torch.Tensor, int]:
"""Encode full audio buffer, with caching for stable windows."""
asr = self.asr
state = self.state
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
_get_feat_extract_output_lengths,
)
feat_out = asr.processor.feature_extractor(
[state.audio_buffer], sampling_rate=16000,
padding=True, truncation=False,
return_attention_mask=True, return_tensors="pt",
)
input_features = feat_out["input_features"].to(asr.device).to(asr.dtype)
feature_attention_mask = feat_out["attention_mask"].to(asr.device)
total_mel_frames = feature_attention_mask.sum().item()
total_audio_tokens = _get_feat_extract_output_lengths(
torch.tensor(total_mel_frames),
).item()
cache = state.audio_cache
audio_cfg = asr.model.thinker.audio_tower.config
n_window_infer = getattr(audio_cfg, "n_window_infer", 400)
n_complete_windows = total_mel_frames // n_window_infer
if n_complete_windows <= 0 or cache.embeddings is None:
# Full encode
audio_embeds = asr.model.thinker.get_audio_features(
input_features, feature_attention_mask=feature_attention_mask,
)
if audio_embeds.dim() == 3:
audio_embeds = audio_embeds[0]
stable_mel = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
stable_tokens = _get_feat_extract_output_lengths(
torch.tensor(stable_mel),
).item() if stable_mel > 0 else 0
else:
stable_mel = n_complete_windows * n_window_infer
stable_tokens = _get_feat_extract_output_lengths(
torch.tensor(stable_mel),
).item()
if cache.stable_tokens > 0 and cache.stable_tokens <= stable_tokens:
cached_prefix = cache.embeddings[:stable_tokens] if cache.embeddings.dim() == 2 else cache.embeddings[0, :stable_tokens]
tail_features = input_features[:, :, stable_mel:]
tail_mel_frames = total_mel_frames - stable_mel
if tail_mel_frames > 0:
tail_mask = torch.ones(
(1, tail_features.shape[2]),
dtype=feature_attention_mask.dtype,
device=feature_attention_mask.device,
)
tail_embeds = asr.model.thinker.get_audio_features(
tail_features, feature_attention_mask=tail_mask,
)
if tail_embeds.dim() == 3:
tail_embeds = tail_embeds[0]
audio_embeds = torch.cat([cached_prefix, tail_embeds], dim=0)
else:
audio_embeds = cached_prefix
else:
audio_embeds = asr.model.thinker.get_audio_features(
input_features, feature_attention_mask=feature_attention_mask,
)
if audio_embeds.dim() == 3:
audio_embeds = audio_embeds[0]
# Update cache
cache.embeddings = audio_embeds if audio_embeds.dim() == 2 else audio_embeds[0]
cache.encoded_samples = len(state.audio_buffer)
cache.encoded_mel_frames = total_mel_frames
stable_mel_final = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
cache.stable_tokens = _get_feat_extract_output_lengths(
torch.tensor(stable_mel_final),
).item() if stable_mel_final > 0 else 0
return audio_embeds, total_audio_tokens
def _build_full_inputs(self, audio_embeds: torch.Tensor) -> dict:
"""Build full input embeddings from prompt + audio embeddings + context."""
asr = self.asr
state = self.state
thinker = asr.model.thinker
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
_get_feat_extract_output_lengths,
)
n_audio_tokens = audio_embeds.shape[0]
prompt_with_placeholders = asr.processor.replace_multimodal_special_tokens(
[self._base_prompt], iter([n_audio_tokens]),
)[0]
text_ids = asr.processor.tokenizer(
[prompt_with_placeholders], return_tensors="pt", padding=True,
)
input_ids = text_ids["input_ids"].to(asr.device)
attention_mask = text_ids.get("attention_mask")
if attention_mask is not None:
attention_mask = attention_mask.to(asr.device)
# Append committed context tokens
if state.committed_token_ids:
ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
ctx_ids = torch.tensor([ctx], dtype=input_ids.dtype, device=input_ids.device)
input_ids = torch.cat([input_ids, ctx_ids], dim=1)
if attention_mask is not None:
ctx_mask = torch.ones_like(ctx_ids)
attention_mask = torch.cat([attention_mask, ctx_mask], dim=1)
# Build inputs_embeds
inputs_embeds = thinker.get_input_embeddings()(input_ids)
audio_mask = (input_ids == asr.audio_token_id)
n_placeholders = audio_mask.sum().item()
if n_placeholders != n_audio_tokens:
logger.warning("Audio token mismatch: %d vs %d", n_placeholders, n_audio_tokens)
return None
audio_embeds_cast = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
expand_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(expand_mask, audio_embeds_cast)
# Find audio token range
audio_positions = audio_mask[0].nonzero(as_tuple=True)[0]
audio_start = audio_positions[0].item()
audio_end = audio_positions[-1].item() + 1
return {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"audio_start": audio_start,
"audio_end": audio_end,
"n_audio_tokens": n_audio_tokens,
}
@torch.inference_mode()
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE
if audio_duration < self.asr.cfg.audio_min_len:
return [], self.end
new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
min_new_seconds = 1.0
if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE):
return [], self.end
self.state.last_infer_samples = len(self.state.audio_buffer)
try:
timestamped_words = self._infer(is_last)
except Exception as e:
logger.exception("Inference error: %s", e)
self.state.reset_kv()
return [], self.end
if not timestamped_words:
return [], self.end
self.buffer = []
return timestamped_words, self.end
def _infer(self, is_last: bool) -> List[ASRToken]:
"""Run inference with KV cache reuse and alignment-head stopping."""
asr = self.asr
state = self.state
thinker = asr.model.thinker
# Step 1: Encode audio (with caching)
audio_embeds, n_audio_tokens_total = self._encode_audio()
# Step 2: Build full inputs
full_inputs = self._build_full_inputs(audio_embeds)
if full_inputs is None:
state.reset_kv()
return []
input_ids = full_inputs["input_ids"]
inputs_embeds = full_inputs["inputs_embeds"]
attention_mask = full_inputs["attention_mask"]
audio_start = full_inputs["audio_start"]
audio_end = full_inputs["audio_end"]
n_audio_tokens = full_inputs["n_audio_tokens"]
audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE
# Step 3: Full prefill (we always re-prefill since audio tokens change)
# Future optimization: partial prefill when only tail audio changes
out = thinker(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
use_cache=True,
)
kv_cache = out.past_key_values
prompt_len = input_ids.shape[1]
# Step 4: Greedy decode with alignment head stopping
border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))
last_attend_frame = state.last_attend_frame
# Install hooks for alignment head attention extraction
decoder_layers = thinker.model.layers
num_kv_heads = asr.num_kv_heads
num_heads = asr.num_heads
gqa_ratio = num_heads // num_kv_heads
from qwen_asr.core.transformers_backend.modeling_qwen3_asr import apply_rotary_pos_emb
per_step_frames: List[List[int]] = []
current_step_frames: List[int] = []
hooks = []
def _make_attn_hook(layer_idx):
head_indices = asr.heads_by_layer[layer_idx]
def hook_fn(module, args, kwargs, output):
hidden_states = kwargs.get('hidden_states')
if hidden_states is None:
hidden_states = args[0] if args else None
if hidden_states is None or hidden_states.shape[1] != 1:
return
position_embeddings = kwargs.get('position_embeddings')
if position_embeddings is None and len(args) > 1:
position_embeddings = args[1]
past_kv = kwargs.get('past_key_values')
if position_embeddings is None or past_kv is None:
return
hidden_shape = (*hidden_states.shape[:-1], -1, module.head_dim)
q = module.q_norm(module.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
cos, sin = position_embeddings
q, _ = apply_rotary_pos_emb(q, q, cos, sin)
cache_layer = past_kv.layers[module.layer_idx]
k = cache_layer.keys
if k is None or audio_end > k.shape[2]:
return
for h_idx in head_indices:
if h_idx >= q.shape[1]:
continue
kv_h_idx = h_idx // gqa_ratio
q_h = q[0, h_idx, 0]
k_audio = k[0, kv_h_idx, audio_start:audio_end]
scores = torch.matmul(k_audio, q_h)
frame = scores.argmax().item()
current_step_frames.append(frame)
return hook_fn
for layer_idx in asr.heads_by_layer:
if layer_idx < len(decoder_layers):
h = decoder_layers[layer_idx].self_attn.register_forward_hook(
_make_attn_hook(layer_idx), with_kwargs=True,
)
hooks.append(h)
try:
# Greedy decoding with alignment-based stopping
next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
generated_ids = []
border_stop_step = None
tokens_per_sec = 6
if is_last:
max_tokens = min(int(audio_duration * tokens_per_sec) + 10, 120)
else:
new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
max_tokens = min(int(max(new_audio_secs, 1.0) * tokens_per_sec) + 5, 40)
for step in range(max_tokens):
tid = next_token.item()
if tid in asr.eos_ids:
break
generated_ids.append(tid)
# Collect alignment frames for this step
if current_step_frames:
per_step_frames.append(current_step_frames)
current_step_frames = []
# Check stopping criteria (after 3 tokens)
if not is_last and len(per_step_frames) >= 3:
latest = per_step_frames[-1]
if latest:
frames_sorted = sorted(latest)
attended = frames_sorted[len(frames_sorted) // 2]
if last_attend_frame - attended > rewind_threshold:
border_stop_step = max(0, len(per_step_frames) - 2)
break
last_attend_frame = attended
if (n_audio_tokens - attended) <= border_threshold:
border_stop_step = len(per_step_frames) - 1
break
# Next token
out = thinker(
input_ids=next_token,
past_key_values=kv_cache,
use_cache=True,
)
kv_cache = out.past_key_values
next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
# Flush remaining frames
if current_step_frames:
per_step_frames.append(current_step_frames)
finally:
for h in hooks:
h.remove()
state.last_attend_frame = last_attend_frame
if not generated_ids:
return []
# Strip metadata prefix (<asr_text> token)
all_generated = torch.tensor(generated_ids, device=asr.device)
num_gen = len(generated_ids)
asr_text_id = asr.asr_text_token_id
metadata_offset = 0
for i in range(min(num_gen, 10)):
if generated_ids[i] == asr_text_id:
if state.detected_language is None and i > 0:
from whisperlivekit.qwen3_asr import QWEN3_TO_WHISPER_LANGUAGE
prefix_text = asr.processor.tokenizer.decode(
generated_ids[:i], skip_special_tokens=True,
).strip()
parts = prefix_text.split()
if len(parts) >= 2:
lang_name = parts[-1]
if lang_name.lower() != "none":
state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
lang_name, lang_name.lower(),
)
metadata_offset = i + 1
break
if metadata_offset > 0:
generated_ids = generated_ids[metadata_offset:]
num_gen -= metadata_offset
per_step_frames = per_step_frames[metadata_offset:]
if num_gen <= 0:
return []
# Determine emit count
if border_stop_step is not None:
emit_up_to = min(border_stop_step, num_gen)
else:
emit_up_to = num_gen
emitted_ids = generated_ids[:emit_up_to]
if not emitted_ids:
return []
# Build timestamped words
words = self._build_timestamped_words(
emitted_ids, per_step_frames, emit_up_to,
n_audio_tokens, audio_duration,
)
state.committed_word_count += len(words)
# Include metadata in committed tokens for context
all_emitted = generated_ids[:emit_up_to]
if metadata_offset > 0:
all_emitted = generated_ids[:emit_up_to] # already stripped
state.committed_token_ids.extend(all_emitted)
return words
def _build_timestamped_words(
self,
generated_ids: list,
step_frames: List[List[int]],
emit_up_to: int,
n_audio_tokens: int,
audio_duration: float,
) -> List[ASRToken]:
asr = self.asr
state = self.state
per_token_frame = []
for step in range(emit_up_to):
if step < len(step_frames) and step_frames[step]:
frames = sorted(step_frames[step])
per_token_frame.append(frames[len(frames) // 2])
else:
per_token_frame.append(None)
tokenizer = asr.processor.tokenizer
full_text = tokenizer.decode(generated_ids[:emit_up_to], skip_special_tokens=True)
text_words = full_text.split()
all_frames = [f for f in per_token_frame if f is not None]
words = []
for wi, word in enumerate(text_words):
if all_frames:
frac = wi / max(len(text_words), 1)
frame_idx = min(int(frac * len(all_frames)), len(all_frames) - 1)
frame = all_frames[frame_idx]
else:
frame = None
words.append((word, frame))
tokens = []
for i, (text, frame) in enumerate(words):
text = text.strip()
if not text:
continue
if frame is not None and n_audio_tokens > 0:
timestamp = (
frame / n_audio_tokens * audio_duration
+ state.cumulative_time_offset
)
else:
timestamp = (
(i / max(len(words), 1)) * audio_duration
+ state.cumulative_time_offset
)
is_very_first_word = (i == 0 and state.committed_word_count == 0)
display_text = text if is_very_first_word else " " + text
token = ASRToken(
start=round(timestamp, 2),
end=round(timestamp + 0.1, 2),
text=display_text,
speaker=state.speaker,
detected_language=state.detected_language,
).with_offset(state.global_time_offset)
tokens.append(token)
return tokens
def warmup(self, audio: np.ndarray, init_prompt: str = ""):
try:
self.state.audio_buffer = audio[:SAMPLE_RATE]
self.process_iter(is_last=True)
self.state = Qwen3SimulKVState()
except Exception as e:
logger.warning("Warmup failed: %s", e)
self.state = Qwen3SimulKVState()
def finish(self) -> Tuple[List[ASRToken], float]:
all_tokens = []
for _ in range(5):
tokens, _ = self.process_iter(is_last=True)
if not tokens:
break
all_tokens.extend(tokens)
return all_tokens, self.end