5 Commits

Author SHA1 Message Date
Quentin Fuxa
a246ba9bfe v0 2025-11-09 22:02:15 +01:00
Quentin Fuxa
7108d2ddc5 fixes https://github.com/QuentinFuxa/WhisperLiveKit/issues/269 2025-11-09 20:08:18 +01:00
Quentin Fuxa
a732e0903e Add a script to detect alignement heads, usefull for distilled whisper 2025-11-09 18:12:09 +01:00
Quentin Fuxa
0491681be4 Distilled model compatibility with HF config.json to ModelDimensions 2025-11-08 20:20:05 +01:00
Quentin Fuxa
ffe5284764 _processing_tasks_done checks task completion 2025-11-05 23:34:00 +01:00
12 changed files with 711 additions and 116 deletions

View File

@@ -171,7 +171,8 @@ async def websocket_endpoint(websocket: WebSocket):
| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="WhisperLiveKit Demo" width="300">
| `None` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |

View File

@@ -3,12 +3,14 @@
The `--model-path` parameter accepts:
## File Path
- **`.pt` format only** (required for AlignAtt policy decoder)
- **`.pt` / `.bin` / `.safetensor` formats** Should be openable by pytorch/safetensor.
## Directory Path (recommended)
Must contain:
- **`.pt` file** (required for decoder)
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
May optionally contain:
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
To improve speed/reduce allucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads to use for your model, and use the `--custom-alignment-heads` to pass them to WLK. If not, alignement heads are set to be all the heads of the last half layer of decoder.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.13"
version = "0.2.13.post1"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
@@ -30,7 +30,6 @@ dependencies = [
"fastapi",
"librosa",
"soundfile",
"faster-whisper",
"uvicorn",
"websockets",
"torchaudio>=2.0.0",

BIN
scripts/alignment_heads.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 276 KiB

View File

@@ -0,0 +1,292 @@
"""Determine alignment heads for a variants, such as distilled model"""
from __future__ import annotations
import argparse
import base64
import gzip
import io
import pathlib
import sys
import math
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from datasets import Audio as DatasetAudio, load_dataset
import soundfile as sf
import matplotlib.pyplot as plt
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
WHISPER_ROOT = REPO_ROOT / "whisper"
sys.path.insert(0, str(REPO_ROOT))
sys.path.insert(0, str(WHISPER_ROOT))
from whisper import load_model
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
def load_dataset_clips(name, config, split, limit):
ds = load_dataset(name, config, split=split)
ds = ds.cast_column("audio", DatasetAudio(decode=False))
clips = []
for idx, row in enumerate(ds):
if limit is not None and idx >= limit:
break
audio_field = row["audio"]
transcript = row["text"]
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
if waveform_np.ndim > 1:
waveform_np = waveform_np.mean(axis=1)
waveform = waveform_np
transcript = str(transcript)
clips.append((waveform, transcript))
return clips
def load_clips(args):
return load_dataset_clips(
args.dataset,
args.dataset_config,
args.dataset_split,
args.dataset_num_samples,
)
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
return waveform
def _parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="pytorch_model.bin",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Torch device to run on",
)
parser.add_argument(
"--dataset",
type=str,
default="librispeech_asr"
)
parser.add_argument(
"--dataset-config",
type=str,
default="clean"
)
parser.add_argument(
"--dataset-split",
type=str,
default="validation[:1%]",
)
parser.add_argument(
"--dataset-num-samples",
type=int,
default=16,
)
parser.add_argument(
"--threshold",
type=float,
default=1.5,
help="Z score threshold for a head to be selected",
)
parser.add_argument(
"--votes",
type=float,
default=0.75,
help="percentage of clips that must vote for a head",
)
parser.add_argument(
"--output",
type=str,
default="alignment_heads.b85",
)
parser.add_argument(
"--visualize-top-k",
type=int,
default=32,
)
return parser.parse_args()
def collect_heads(
model,
tokenizer,
clips: Sequence[Tuple[AudioInput, str]],
threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = model.device
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
strengths = torch.zeros_like(votes)
for audio_source, transcript in clips:
waveform = pad_or_trim(_waveform_from_source(audio_source))
mel = log_mel_spectrogram(waveform, device=device)
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=device,
)
qks = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
for layer_idx, tensor in enumerate(qks):
if tensor is None:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1)
peak = tensor.max(dim=-1).values # [heads, tokens]
strengths[layer_idx] += peak.mean(dim=-1)
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
)
mask = (zscore > 3).any(dim=-1)
votes[layer_idx] += mask.float()
votes /= len(clips)
strengths /= len(clips)
return votes, strengths
def _select_heads_for_visualization(selection, strengths, top_k):
selected = torch.nonzero(selection, as_tuple=False)
if selected.numel() == 0:
return []
entries = [
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
for layer, head in selected
]
entries.sort(key=lambda item: item[2], reverse=True)
return entries[:top_k]
def _extract_heatmaps(
model,
tokenizer,
clip: Tuple[AudioInput, str],
heads: Sequence[Tuple[int, int, float]],
) -> dict:
if not heads:
return {}
target_map = {}
for layer, head, _ in heads:
target_map.setdefault(layer, set()).add(head)
waveform = pad_or_trim(_waveform_from_source(clip[0]))
mel = log_mel_spectrogram(waveform, device=model.device)
transcript = clip[1]
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*tokenizer.encode(transcript),
tokenizer.eot,
],
device=model.device,
)
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
model(mel.unsqueeze(0), tokens.unsqueeze(0))
for hook in hooks:
hook.remove()
heatmaps = {}
for layer_idx, tensor in enumerate(QKs):
if tensor is None or layer_idx not in target_map:
continue
tensor = tensor[:, :, : mel.shape[-1] // 2]
tensor = tensor.softmax(dim=-1).cpu()
for head_idx in target_map[layer_idx]:
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
return heatmaps
def _plot_heatmaps(
heads, heatmaps, output_path):
cols = min(3, len(heads))
rows = math.ceil(len(heads) / cols)
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
for idx, (layer, head, score) in enumerate(heads):
ax = axes[idx // cols][idx % cols]
mat = heatmaps.get((layer, head))
if mat is None:
ax.axis("off")
continue
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
ax.set_xlabel("time")
ax.set_ylabel("tokens")
for j in range(len(heads), rows * cols):
axes[j // cols][j % cols].axis("off")
fig.tight_layout()
fig.savefig(output_path, dpi=200)
plt.close(fig)
def _dump_mask(mask: torch.Tensor, output_path: str):
payload = mask.numpy().astype(np.bool_)
blob = base64.b85encode(gzip.compress(payload.tobytes()))
with open(output_path, "wb") as f:
f.write(blob)
def main():
args = _parse_args()
model = load_model(args.model, device=args.device)
model.eval()
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
clips = load_clips(args)
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
# selection = votes > 0.5
selection = strengths > 0.05
_dump_mask(selection.cpu(), args.output)
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
if __name__ == "__main__":
main()

View File

@@ -1,9 +1,10 @@
"""Copy core files from web directory to Chrome extension directory."""
import shutil
import os
from pathlib import Path
def sync_extension_files():
"""Copy core files from web directory to Chrome extension directory."""
web_dir = Path("whisperlivekit/web")
extension_dir = Path("chrome-extension")

View File

@@ -464,7 +464,7 @@ class AudioProcessor:
yield response
self.last_response_content = response
if self.is_stopping and self.transcription_task and self.transcription_task.done() and self.diarization_task and self.diarization_task.done():
if self.is_stopping and self._processing_tasks_done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return
@@ -517,11 +517,16 @@ class AudioProcessor:
async def watchdog(self, tasks_to_monitor):
"""Monitors the health of critical processing tasks."""
tasks_remaining = [task for task in tasks_to_monitor if task]
while True:
try:
if not tasks_remaining:
logger.info("Watchdog task finishing: all monitored tasks completed.")
return
await asyncio.sleep(10)
for i, task in enumerate(tasks_to_monitor):
for i, task in enumerate(list(tasks_remaining)):
if task.done():
exc = task.exception()
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
@@ -529,6 +534,7 @@ class AudioProcessor:
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
else:
logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task)
except asyncio.CancelledError:
logger.info("Watchdog task cancelled.")
@@ -559,6 +565,16 @@ class AudioProcessor:
self.diarization.close()
logger.info("AudioProcessor cleanup complete.")
def _processing_tasks_done(self):
"""Return True when all active processing tasks have completed."""
tasks_to_check = [
self.transcription_task,
self.diarization_task,
self.translation_task,
self.ffmpeg_reader_task,
]
return all(task.done() for task in tasks_to_check if task)
async def process_audio(self, message):
"""Process incoming audio data."""

View File

@@ -81,7 +81,8 @@ def no_token_to_silence(tokens):
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
current_time = time() - (beg_loop if beg_loop else 0.0)
last_token = tokens[-1]
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
silence_duration = current_time - last_token.end
if (vac_detected_silence and silence_duration > END_SILENCE_DURATION_VAC) or (silence_duration >= END_SILENCE_DURATION):
if last_token.speaker == -2:
last_token.end = current_time
else:
@@ -102,5 +103,4 @@ def handle_silences(tokens, beg_loop, vac_detected_silence):
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
tokens = no_token_to_silence(tokens)
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
return tokens
return tokens

View File

@@ -1,7 +1,6 @@
import logging
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time
from whisperlivekit.timed_objects import Line, Segment, format_time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -72,83 +71,96 @@ def format_output(state, silence, args, sep):
token.corrected_speaker = 1
token.validated_speaker = True
else:
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if token.speaker != previous_speaker:
if is_punctuation(token):
last_punctuation = i
if last_punctuation == i-1:
if token.speaker != previous_speaker:
token.validated_speaker = True
# perfect, diarization perfectly aligned
last_punctuation = None
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
# perfect, diarization perfectly aligned
last_punctuation = None
else:
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
if speaker_change_pos:
# Corrects delay:
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
# should become:
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
token.corrected_speaker = new_speaker
token.validated_speaker = True
elif speaker != previous_speaker:
if not (speaker == -2 or previous_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
elif speaker != previous_speaker:
if not (speaker == -2 or previous_speaker == -2):
if next_punctuation_change(i, tokens):
# Corrects advance:
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
# should become:
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
token.corrected_speaker = previous_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = previous_speaker
token.validated_speaker = True
else: #Problematic, except if the language has no punctuation. We append to previous line, except if disable_punctuation_split is set to True.
if not disable_punctuation_split:
token.corrected_speaker = previous_speaker
token.validated_speaker = False
token.validated_speaker = False
if token.validated_speaker:
state.last_validated_token = i
state.last_validated_token = i + last_validated_token
previous_speaker = token.corrected_speaker
previous_speaker = 1
for token in tokens[last_validated_token+1:state.last_validated_token+1]:
if not state.segments or int(token.corrected_speaker) != int(state.segments[-1].speaker):
state.segments.append(
Segment(
speaker=token.corrected_speaker,
words=[token]
)
)
else:
state.segments[-1].words.append(token)
for token in tokens[state.last_validated_token+1:]:
# if not state.segments or int(token.corrected_speaker) != int(state.segments[-1].speaker):
# state.segments.append(
# Segment(
# speaker=token.corrected_speaker,
# buffer_tokens=[token]
# )
# )
# else:
state.segments[-1].buffer_tokens.append(token)
for segment in state.segments:
segment.consolidate(sep)
# lines = []
# for token in tokens:
# if int(token.corrected_speaker) != int(previous_speaker):
# lines.append(new_line(token))
# else:
# append_token_to_last_line(lines, sep, token)
# previous_speaker = token.corrected_speaker
for ts in translation_validated_segments:
for segment in state.segments[state.last_validated_segment:]:
if ts.is_within(segment):
segment.translation += ts.text + sep
break
for ts in translation_buffer:
for segment in state.segments[state.last_validated_segment:]:
if ts.is_within(segment):
segment.buffer.translation += ts.text + sep
break
# if state.buffer_transcription and lines:
# lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
lines = []
for token in tokens:
if int(token.corrected_speaker) != int(previous_speaker):
lines.append(new_line(token))
else:
append_token_to_last_line(lines, sep, token)
previous_speaker = token.corrected_speaker
if lines:
unassigned_translated_segments = []
for ts in translation_validated_segments:
assigned = False
for line in lines:
if ts and ts.overlaps_with(line):
if ts.is_within(line):
line.translation += ts.text + ' '
assigned = True
break
else:
ts0, ts1 = ts.approximate_cut_at(line.end)
if ts0 and line.overlaps_with(ts0):
line.translation += ts0.text + ' '
if ts1:
unassigned_translated_segments.append(ts1)
assigned = True
break
if not assigned:
unassigned_translated_segments.append(ts)
if unassigned_translated_segments:
for line in lines:
remaining_segments = []
for ts in unassigned_translated_segments:
if ts and ts.overlaps_with(line):
line.translation += ts.text + ' '
else:
remaining_segments.append(ts)
unassigned_translated_segments = remaining_segments #maybe do smth in the future about that
for segment in state.segments:
lines.append(Line(
start=segment.start,
end=segment.end,
speaker=segment.speaker,
text=segment.text,
translation=segment.translation
))
if state.buffer_transcription and lines:
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
return lines, undiarized_text
return lines, undiarized_text

View File

@@ -23,7 +23,7 @@ try:
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper\n{"="*50}""")
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: `pip install mlx-whisper`\n{"="*50}""")
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
@@ -32,6 +32,8 @@ else:
from faster_whisper import WhisperModel
HAS_FASTER_WHISPER = True
except ImportError:
if platform.system() != "Darwin":
print(f"""{"="*50}\nFaster-Whisper not found but. Consider installing faster-whisper for better performance: `pip install faster-whisper`\n{"="*50}`""")
HAS_FASTER_WHISPER = False
def model_path_and_type(model_path):
@@ -39,9 +41,10 @@ def model_path_and_type(model_path):
compatible_whisper_mlx = False
compatible_faster_whisper = False
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
if path.is_dir():
pytorch_path = None
if path.is_file() and path.suffix.lower() in ['.pt', '.safetensors', '.bin']:
pytorch_path = path
elif path.is_dir():
for file in path.iterdir():
if file.is_file():
if file.name in ['weights.npz', "weights.safetensors"]:
@@ -49,8 +52,13 @@ def model_path_and_type(model_path):
elif file.suffix.lower() == '.bin':
compatible_faster_whisper = True
elif file.suffix.lower() == '.pt':
pt_path = file
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
pytorch_path = file
elif file.suffix.lower() == '.safetensors':
pytorch_path = file
if pytorch_path is None:
if (model_path / Path("pytorch_model.bin")).exists():
pytorch_path = model_path / Path("pytorch_model.bin")
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
class SimulStreamingOnlineProcessor:
@@ -169,11 +177,11 @@ class SimulStreamingASR():
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
self.fast_encoder = False
pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
if self.model_path:
pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
self.model_name = self.pytorch_path.stem
is_multilingual = not self.model_path.endswith(".en")
elif self.model_size is not None:
model_mapping = {
'tiny': './tiny.pt',
@@ -189,12 +197,11 @@ class SimulStreamingASR():
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
self.model_name = pt_path.name.replace(".pt", "")
self.model_name = self.model_size
is_multilingual = not self.model_name.endswith(".en")
self.cfg = AlignAttConfig(
tokenizer_is_multilingual= not self.model_name.endswith(".en"),
tokenizer_is_multilingual= is_multilingual,
segment_length=self.min_chunk_size,
frame_threshold=self.frame_threshold,
language=self.lan,
@@ -226,9 +233,10 @@ class SimulStreamingASR():
if self.model_path and compatible_whisper_mlx:
mlx_model = self.model_path
else:
mlx_model = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
self.fast_encoder = True
mlx_model = mlx_model_mapping.get(self.model_name)
if mlx_model:
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
self.fast_encoder = True
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
print('Simulstreaming will use Faster Whisper for the encoder.')
if self.model_path and compatible_faster_whisper:
@@ -247,7 +255,7 @@ class SimulStreamingASR():
def load_model(self):
whisper_model = load_model(
name=self.model_path if self.model_path else self.model_name,
name=self.pytorch_path if self.pytorch_path else self.model_name,
download_root=self.model_path,
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads

View File

@@ -1,12 +1,14 @@
import hashlib
import io
import json
import os
import urllib
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict
import torch
from tqdm import tqdm
from pathlib import Path
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
@@ -100,6 +102,137 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
"""
attempt to infer ModelDimensions from a HF style config.json located
next to the given checkpoint, usefull for distilled models
"""
candidates = []
if os.path.isdir(path):
candidates.append(os.path.join(path, "config.json"))
else:
candidates.append(os.path.join(os.path.dirname(path), "config.json"))
for candidate in candidates:
if not os.path.isfile(candidate):
continue
with open(candidate, "r", encoding="utf-8") as f:
config = json.load(f)
try:
return ModelDimensions(
n_mels=config["num_mel_bins"],
n_audio_ctx=config["max_source_positions"],
n_audio_state=config["d_model"],
n_audio_head=config["encoder_attention_heads"],
n_audio_layer=config.get("encoder_layers")
or config["num_hidden_layers"],
n_vocab=config["vocab_size"],
n_text_ctx=config["max_target_positions"],
n_text_state=config["d_model"],
n_text_head=config["decoder_attention_heads"],
n_text_layer=config["decoder_layers"],
)
except KeyError as err:
warnings.warn(f"Missing key {err} in HuggingFace config {candidate}")
return None
return None
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
converts a HF checkpoint state_dict into the naming convention used by
default whisper
"""
if not any(k.startswith("model.") for k in state_dict):
return state_dict
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
if remainder.startswith("self_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "attn.query",
"k_proj": "attn.key",
"v_proj": "attn.value",
"out_proj": "attn.out",
}
stem = mapping.get(suffix.split(".")[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "self_attn_layer_norm.weight":
return f"{target_prefix}.attn_ln.weight"
elif remainder == "self_attn_layer_norm.bias":
return f"{target_prefix}.attn_ln.bias"
elif remainder.startswith("encoder_attn."):
suffix = remainder.split(".", 1)[1]
mapping = {
"q_proj": "cross_attn.query",
"k_proj": "cross_attn.key",
"v_proj": "cross_attn.value",
"out_proj": "cross_attn.out",
}
stem = mapping.get(suffix.split(".", 1)[0])
if stem:
rest = suffix.split(".", 1)[1] if "." in suffix else ""
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
elif remainder == "encoder_attn_layer_norm.weight":
return f"{target_prefix}.cross_attn_ln.weight"
elif remainder == "encoder_attn_layer_norm.bias":
return f"{target_prefix}.cross_attn_ln.bias"
elif remainder.startswith("fc1."):
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
elif remainder.startswith("fc2."):
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
elif remainder == "final_layer_norm.weight":
return f"{target_prefix}.mlp_ln.weight"
elif remainder == "final_layer_norm.bias":
return f"{target_prefix}.mlp_ln.bias"
return None
converted = {}
for key, value in state_dict.items():
if not key.startswith("model."):
continue
subkey = key[len("model.") :]
if subkey.startswith("encoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("decoder.layers."):
parts = subkey.split(".")
layer_idx = parts[2]
remainder = ".".join(parts[3:])
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
mapped = subkey
elif subkey == "encoder.embed_positions.weight":
mapped = "encoder.positional_embedding"
elif subkey == "decoder.embed_positions.weight":
mapped = "decoder.positional_embedding"
elif subkey == "encoder.layer_norm.weight":
mapped = "encoder.ln_post.weight"
elif subkey == "encoder.layer_norm.bias":
mapped = "encoder.ln_post.bias"
elif subkey.startswith("decoder.embed_tokens."):
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
elif subkey == "decoder.layer_norm.weight":
mapped = "decoder.ln.weight"
elif subkey == "decoder.layer_norm.bias":
mapped = "decoder.ln.bias"
else:
mapped = None
if mapped:
converted[mapped] = value
return converted if converted else state_dict
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
@@ -134,7 +267,6 @@ def load_model(
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
elif os.path.isfile(name):
@@ -148,22 +280,50 @@ def load_model(
if custom_alignment_heads:
alignment_heads = custom_alignment_heads.encode()
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
if isinstance(checkpoint_file, Path) and checkpoint_file.suffix == '.safetensors':
try:
from safetensors.torch import load_file
except ImportError:
raise ImportError("Please install safetensors to load .safetensors model files: `pip install safetensors`")
if in_memory:
checkpoint = load_file(checkpoint_file, device=device)
else:
checkpoint = load_file(checkpoint_file, device=device)
else:
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
state_dict = _convert_hf_state_dict(state_dict)
if dims_cfg is not None:
dims = ModelDimensions(**dims_cfg)
else:
dims = _infer_dims_from_config(name)
if dims is None:
raise RuntimeError(
"Could not determine model dimensions. "
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
)
if not isinstance(state_dict, dict):
state_dict = checkpoint
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
state_dict = {
k: v for k, v in state_dict.items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"])
model.load_state_dict(state_dict)
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

View File

@@ -91,8 +91,12 @@ class SpeakerSegment(TimedText):
@dataclass
class Translation(TimedText):
is_validated : bool = False
pass
# def split(self):
# return self.text.split(" ") # should be customized with the sep
def approximate_cut_at(self, cut_time):
"""
Each word in text is considered to be of duration (end-start)/len(words in text)
@@ -120,6 +124,19 @@ class Translation(TimedText):
return segment0, segment1
def cut_position(self, position):
sep=" "
words = self.text.split(sep)
num_words = len(words)
duration_per_word = self.duration() / num_words
cut_time=duration_per_word*position
text0 = sep.join(words[:position])
text1 = sep.join(words[position:])
segment0 = Translation(start=self.start, end=cut_time, text=text0)
segment1 = Translation(start=cut_time, end=self.end, text=text1)
return segment0, segment1
@dataclass
class Silence():
@@ -143,6 +160,90 @@ class Line(TimedText):
_dict['detected_language'] = self.detected_language
return _dict
@dataclass
class WordValidation:
"""Validation status for word-level data."""
text: bool = False
speaker: bool = False
language: bool = False
def to_dict(self):
return {
'text': self.text,
'speaker': self.speaker,
'language': self.language
}
@dataclass
class Word:
"""Word-level object with timing and validation information."""
text: str = ''
start: float = 0.0
end: float = 0.0
validated: WordValidation = field(default_factory=WordValidation)
def to_dict(self):
return {
'text': self.text,
'start': self.start,
'end': self.end,
'validated': self.validated.to_dict()
}
@dataclass
class SegmentBuffer:
"""Per-segment temporary buffers for ephemeral data."""
transcription: str = ''
diarization: str = ''
translation: str = ''
def to_dict(self):
return {
'transcription': self.transcription,
'diarization': self.diarization,
'translation': self.translation
}
@dataclass
class Segment:
"""Represents a segment in the new API structure."""
id: int = 0
speaker: int = -1
text: str = ''
start_speaker: float = 0.0
start: float = 0.0
end: float = 0.0
language: Optional[str] = None
translation: str = ''
words: List[ASRToken] = field(default_factory=list)
buffer_tokens: List[ASRToken] = field(default_factory=list)
buffer_translation = ''
buffer: SegmentBuffer = field(default_factory=SegmentBuffer)
def to_dict(self):
"""Convert segment to dictionary for JSON serialization."""
return {
'id': self.id,
'speaker': self.speaker,
'text': self.text,
'start_speaker': self.start_speaker,
'start': self.start,
'end': self.end,
'language': self.language,
'translation': self.translation,
'words': [word.to_dict() for word in self.words],
'buffer': self.buffer.to_dict()
}
def consolidate(self, sep):
self.text = sep.join([word.text for word in self.words])
if self.words:
self.start = self.words[0].start
self.end = self.words[-1].end
@dataclass
class FrontData():
@@ -175,7 +276,9 @@ class ChangeSpeaker:
@dataclass
class State():
tokens: list = field(default_factory=list)
segments: list = field(default_factory=list)
last_validated_token: int = 0
last_validated_segment: int = 0 # validated means tokens speaker and transcription are validated and terminated
translation_validated_segments: list = field(default_factory=list)
translation_buffer: list = field(default_factory=list)
buffer_transcription: str = field(default_factory=Transcript)
@@ -183,4 +286,5 @@ class State():
end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0
remaining_time_diarization: float = 0.0
beg_loop: Optional[int] = None
beg_loop: Optional[int] = None