Ruff lint cleanup
https://github.com/QuentinFuxa/WhisperLiveKit.git
@@ -130,18 +130,18 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
     available_ops = [15, 16]
     if opset_version not in available_ops:
         raise ValueError(f'Unsupported ONNX opset_version: {opset_version}. Available: {available_ops}')
-
+
     if model_path is None:
         current_dir = Path(__file__).parent
         data_dir = current_dir / 'silero_vad_models'
-
+
         if opset_version == 16:
             model_name = 'silero_vad.onnx'
         else:
             model_name = f'silero_vad_16k_op{opset_version}.onnx'
-
+
         model_path = data_dir / model_name
-
+
         if not model_path.exists():
             raise FileNotFoundError(
                 f"Model file not found: {model_path}\n"
@@ -149,7 +149,7 @@ def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
             )
     else:
         model_path = Path(model_path)
-
+
     return model_path


@@ -169,9 +169,9 @@ def load_jit_vad(model_path: str = None):
         current_dir = Path(__file__).parent
         data_dir = current_dir / 'silero_vad_models'
         model_name = 'silero_vad.jit'
-
+
         model_path = data_dir / model_name
-
+
         if not model_path.exists():
             raise FileNotFoundError(
                 f"Model file not found: {model_path}\n"
@@ -181,17 +181,17 @@ def load_jit_vad(model_path: str = None):
         model_path = Path(model_path)

     model = init_jit_model(str(model_path))
-
+
     return model


 class VADIterator:
     """
     Voice Activity Detection iterator for streaming audio.
-
+
     This is the Silero VAD v6 implementation.
     """
-
+
     def __init__(self,
                  model,
                  threshold: float = 0.5,
@@ -319,8 +319,8 @@ if __name__ == "__main__":
     audio_buffer = np.array([0] * 512, dtype=np.float32)
     result = vad(audio_buffer)
     print(f" 512 samples: {result}")
-
+
     # test with 511 samples
     audio_buffer = np.array([0] * 511, dtype=np.float32)
     result = vad(audio_buffer)
-    print(f" 511 samples: {result}")
+    print(f" 511 samples: {result}")
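A quick aside on the constraint this `__main__` test exercises: Silero VAD consumes fixed 512-sample frames at 16 kHz (32 ms), which is why the 511-sample call is expected to fail. Below is a minimal framing sketch; the `FRAME` constant and `vad_frames` helper are illustrative and not part of WhisperLiveKit.

import numpy as np

FRAME = 512  # samples per Silero VAD frame at 16 kHz (32 ms)

def vad_frames(pcm: np.ndarray):
    """Yield exact 512-sample frames; a streaming caller keeps the short tail buffered."""
    usable = len(pcm) - len(pcm) % FRAME
    for start in range(0, usable, FRAME):
        yield pcm[start:start + FRAME]

chunk = np.zeros(1100, dtype=np.float32)
print([len(f) for f in vad_frames(chunk)])  # [512, 512]; 76 samples carry over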
@@ -1,7 +1,6 @@
 """Abstract base class for AlignAtt streaming decoders (PyTorch & MLX)."""
 import logging
 from abc import ABC, abstractmethod
 from typing import Any, List, Optional, Tuple

 from whisperlivekit.timed_objects import ASRToken
-from whisperlivekit.whisper import DecodingOptions, tokenizer
@@ -151,7 +150,7 @@ class AlignAttBase(ABC):
         if seconds_since_start >= 2.0:
             language_tokens, language_probs = self.lang_id(encoder_feature)
             top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
-            print(f"Detected language: {top_lan} with p={p:.4f}")
+            logger.info(f"Detected language: {top_lan} with p={p:.4f}")
             self.create_tokenizer(top_lan)
             self.state.last_attend_frame = -self.cfg.rewind_threshold
             self.state.cumulative_time_offset = 0.0
@@ -1,31 +1,27 @@
-import gc
 import logging
-import os
-import platform
 import sys
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import List, Tuple

 import numpy as np
 import torch

-from whisperlivekit.backend_support import (faster_backend_available,
-                                            mlx_backend_available)
+from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
 from whisperlivekit.model_paths import detect_model_format, resolve_model_path
 from whisperlivekit.simul_whisper.config import AlignAttConfig
 from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
 from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript
 from whisperlivekit.warmup import load_file
 from whisperlivekit.whisper import load_model, tokenizer
 from whisperlivekit.whisper.audio import TOKENS_PER_SECOND

 logger = logging.getLogger(__name__)


 HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
 if HAS_MLX_WHISPER:
-    from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
     from .mlx import MLXAlignAtt
+    from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
 else:
     mlx_model_mapping = {}
     MLXAlignAtt = None
@@ -47,7 +43,7 @@ class SimulStreamingOnlineProcessor:
         self.end = 0.0
         self.buffer = []
         self.model = self._create_alignatt()
-
+
         if asr.tokenizer:
             self.model.tokenizer = asr.tokenizer
             self.model.state.tokenizer = asr.tokenizer
@@ -99,7 +95,7 @@ class SimulStreamingOnlineProcessor:
         self.model.refresh_segment(complete=True)
         self.model.speaker = change_speaker.speaker
         self.model.global_time_offset = change_speaker.start
-
+
     def get_buffer(self):
         concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
         return concat_buffer
@@ -107,19 +103,19 @@ class SimulStreamingOnlineProcessor:
     def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
         """
         Process accumulated audio chunks using SimulStreaming.
-
+
         Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
         """
         try:
             timestamped_words = self.model.infer(is_last=is_last)
-
+
             if not timestamped_words:
                 return [], self.end
-
+
             if self.model.cfg.language == "auto" and timestamped_words[0].detected_language is None:
                 self.buffer.extend(timestamped_words)
                 return [], self.end
-
+
             self.buffer = []
             return timestamped_words, self.end
         except Exception as e:
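For orientation, a hedged sketch of the contract documented above: process_iter() returns committed tokens plus the audio time processed so far, and an empty token list simply means nothing was committed this step. The stub below only imitates that interface; it is not the real class.

from typing import List, Tuple

class StubProcessor:
    """Stand-in mimicking process_iter's (tokens, end_time) contract above."""
    def __init__(self):
        self._steps = [(["hello"], 0.5), ([], 0.9), (["world"], 1.4)]
    def process_iter(self, is_last: bool = False) -> Tuple[List[str], float]:
        return self._steps.pop(0) if self._steps else ([], 1.4)

online = StubProcessor()
committed: List[str] = []
for _ in range(3):                              # feed/poll loop of a hypothetical caller
    tokens, upto = online.process_iter()
    committed.extend(tokens)
tokens, _ = online.process_iter(is_last=True)   # final flush at end of stream
committed.extend(tokens)
print(committed)  # ['hello', 'world']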
@@ -156,7 +152,7 @@ class SimulStreamingASR:
     def __init__(self, logfile=sys.stderr, **kwargs):
         self.logfile = logfile
         self.transcribe_kargs = {}
-
+
         for key, value in kwargs.items():
             setattr(self, key, value)

@@ -169,20 +165,20 @@ class SimulStreamingASR:
         self.use_full_mlx = getattr(self, "use_full_mlx", False)
         preferred_backend = getattr(self, "backend", "auto")
         compatible_whisper_mlx, compatible_faster_whisper = True, True
-
+
         if self.model_path:
             resolved_model_path = resolve_model_path(self.model_path)
             self._resolved_model_path = resolved_model_path
             self.model_path = str(resolved_model_path)
-
+
             model_info = detect_model_format(resolved_model_path)
             compatible_whisper_mlx = model_info.compatible_whisper_mlx
             compatible_faster_whisper = model_info.compatible_faster_whisper
-
+
             if not self.use_full_mlx and not model_info.has_pytorch:
                 raise FileNotFoundError(
                     f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
-                )
+                )
             self.model_name = resolved_model_path.name if resolved_model_path.is_dir() else resolved_model_path.stem
         elif self.model_size is not None:
             self.model_name = self.model_size
@@ -199,14 +195,14 @@ class SimulStreamingASR:
         self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
         if self.encoder_backend == "whisper":
             self.disable_fast_encoder = True
-
+
         # MLX full decoder disabled by default — MLXAlignAtt has known issues
         # with token generation after punctuation. Users can opt-in with
         # --use-full-mlx if they want to test it.
         # if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
         #     if not hasattr(self, '_full_mlx_disabled'):
         #         self.use_full_mlx = True
-
+
         self.cfg = AlignAttConfig(
             tokenizer_is_multilingual= is_multilingual,
             segment_length=self.min_chunk_size,
@@ -222,8 +218,8 @@ class SimulStreamingASR:
             init_prompt=self.init_prompt,
             max_context_tokens=self.max_context_tokens,
             static_init_prompt=self.static_init_prompt,
-        )
-
+        )
+
         # Set up tokenizer for translation if needed
         if self.direct_english_translation:
             self.tokenizer = self.set_translate_task()
@@ -232,7 +228,7 @@ class SimulStreamingASR:

         self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
         self.shared_model = None
-
+
         if self.use_full_mlx and HAS_MLX_WHISPER:
             logger.info('MLX Whisper backend used.')
             if self._resolved_model_path is not None:
@@ -259,7 +255,7 @@ class SimulStreamingASR:
             self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
             self.shared_model = self.load_model()
         elif self.encoder_backend == "faster-whisper":
-            print('SimulStreaming will use Faster Whisper for the encoder.')
+            logger.info('SimulStreaming will use Faster Whisper for the encoder.')
             if self._resolved_model_path is not None:
                 fw_model = str(self._resolved_model_path)
             else:
@@ -272,7 +268,7 @@ class SimulStreamingASR:
             self.shared_model = self.load_model()
         else:
             self.shared_model = self.load_model()
-
+
     def _warmup_mlx_model(self):
         """Warmup the full MLX model."""
         warmup_audio = load_file(self.warmup_file)
@@ -19,14 +19,14 @@ class BeamPyTorchInference(PyTorchInference):
             self.kv_cache[cache_id] = self.kv_cache[cache_id][source_indices].detach()

     def logits(
-        self,
-        tokens: Tensor,
+        self,
+        tokens: Tensor,
         audio_features: Tensor,
         return_cross_attn: bool = False,
     ):
         """Get logits, optionally returning cross-attention weights."""
         return self.model.decoder(
-            tokens, audio_features,
+            tokens, audio_features,
             kv_cache=self.kv_cache,
             return_cross_attn=return_cross_attn,
-        )
+        )
@@ -21,4 +21,3 @@ class AlignAttConfig():
     init_prompt: str = field(default=None)
     static_init_prompt: str = field(default=None)
     max_context_tokens: int = field(default=None)
-
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple
+
 import torch


@@ -7,23 +8,23 @@ import torch
 class DecoderState:

     kv_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
-
+
     tokenizer: Any = None
     detected_language: Optional[str] = None
     reset_tokenizer_to_auto_next_call: bool = False
-
+
     tokens: List[torch.Tensor] = field(default_factory=list)
     initial_tokens: Optional[torch.Tensor] = None
     initial_token_length: int = 0
     sot_index: int = 0
-
+
     align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
     num_align_heads: int = 0
-
+
     segments: List[torch.Tensor] = field(default_factory=list)
-
+
     context: Any = None
-
+
     pending_incomplete_tokens: List[int] = field(default_factory=list)
     pending_retries: int = 0
@@ -31,21 +32,21 @@ class DecoderState:
     cumulative_time_offset: float = 0.0
     first_timestamp: Optional[float] = None
     last_attend_frame: int = 0
-
+
     speaker: int = -1
     log_segments: int = 0
-
+
     CIFLinear: Optional[torch.nn.Module] = None
     always_fire: bool = False
     never_fire: bool = False
-
+
     suppress_tokens_fn: Any = None
-
+
     token_decoder: Any = None
     decoder_type: str = "greedy"
-
+
     inference: Any = None
-
+
     def clean_cache(self):
         """Clean the kv_cache after each inference step."""
         # Explicitly delete tensor references to free GPU memory
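The clean_cache comment above is about reference lifetimes: memory backing a cached tensor is only reclaimable once no Python reference pins it. A minimal sketch of the pattern (CPU tensors here, so it runs anywhere):

import torch

kv_cache = {"dec_layer0_key": torch.zeros(1, 4, 8)}
for key in list(kv_cache):
    del kv_cache[key]   # drop each tensor reference so the allocator can reuse it
kv_cache = {}
print(kv_cache)         # {}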
@@ -68,11 +69,11 @@ class DecoderState:
             self.inference.kv_cache = {}
         if self.token_decoder is not None:
             self.token_decoder.reset()
-
+
     def reset(self, rewind_threshold: int = 200):
         """
         Reset transient state for a new segment.
-
+
         Args:
             rewind_threshold: Value for resetting last_attend_frame
         """
@@ -85,7 +86,7 @@ class DecoderState:
     def full_reset(self, rewind_threshold: int = 200):
         """
         Full reset including audio segments and tokens.
-
+
         Args:
             rewind_threshold: Value for resetting last_attend_frame
         """
@@ -46,7 +46,7 @@ def resize(alphas, target_lengths, threshold=0.999):
         _alphas[x] = _alphas[x] * 0.5 + mean * mask

     return _alphas, _num
-
+
 def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
     content_mel_len = chunked_encoder_feature.shape[1] # B, T, D
     alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2) # B, T
@@ -62,4 +62,4 @@ def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
     if important_positions.numel() == 0:
         return False
-    else:
-        return important_positions[0] >= content_mel_len-2
+    return important_positions[0] >= content_mel_len-2
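A toy check of the boundary rule above, with made-up weights and a hypothetical 0.9 threshold (the actual thresholding inside fire_at_boundary is not fully shown in this hunk): the chunk fires only when the first strong CIF position falls within the last two encoder frames.

import torch

alphas = torch.tensor([0.10, 0.20, 0.05, 0.97, 0.30])  # per-frame CIF weights, T = 5
content_mel_len = alphas.shape[0]
important_positions = (alphas > 0.9).nonzero().squeeze(1)
fires = important_positions.numel() > 0 and bool(important_positions[0] >= content_mel_len - 2)
print(fires)  # True: the only strong position is index 3, and 3 >= 5 - 2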
@@ -13,21 +13,21 @@ class MLXDecoderState:
     """

     kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
-
+
     tokenizer: Any = None
     detected_language: Optional[str] = None
     reset_tokenizer_to_auto_next_call: bool = False
-
+
     tokens: List[mx.array] = field(default_factory=list)
     initial_tokens: Optional[mx.array] = None
     initial_token_length: int = 0
-    sot_index: int = 0
+    sot_index: int = 0
     align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
-    num_align_heads: int = 0
+    num_align_heads: int = 0
     segments: List[np.ndarray] = field(default_factory=list)
-
+
     context: Any = None
-
+
     pending_incomplete_tokens: List[int] = field(default_factory=list)
     pending_retries: int = 0
@@ -35,27 +35,27 @@ class MLXDecoderState:
     cumulative_time_offset: float = 0.0
     first_timestamp: Optional[float] = None
     last_attend_frame: int = 0
-
+
     speaker: int = -1
-    log_segments: int = 0
+    log_segments: int = 0
     cif_weights: Optional[mx.array] = None
     always_fire: bool = False
     never_fire: bool = False
-
+
     suppress_tokens: Optional[Tuple[int, ...]] = None
-
+
     token_decoder: Any = None
     decoder_type: str = "greedy"
-
+
     inference: Any = None
-
+
     def clean_cache(self):
         self.kv_cache = None
         if self.decoder_type == "beam" and self.inference is not None:
             self.inference.kv_cache = None
         if self.token_decoder is not None:
             self.token_decoder.reset()
-
+
     def reset(self, rewind_threshold: int = 200):
         self.last_attend_frame = -rewind_threshold
         self.cumulative_time_offset = 0.0
@@ -9,7 +9,7 @@ import numpy as np

 class MLXGreedyDecoder:
     """Greedy decoder using MLX operations."""
-
+
     def __init__(self, temperature: float, eot: int):
         self.temperature = temperature
         self.eot = eot
@@ -33,18 +33,18 @@ class MLXGreedyDecoder:
         else:
             probs = mx.softmax(logits / self.temperature, axis=-1)
             next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
-
+
         logprobs = mx.softmax(logits, axis=-1)
-        logprobs = mx.log(logprobs + 1e-10)
+        logprobs = mx.log(logprobs + 1e-10)
         batch_size = logprobs.shape[0]
-        current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
+        current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
         mask = (tokens[:, -1] != self.eot).astype(mx.float32)
-        sum_logprobs = sum_logprobs + current_logprobs * mask
+        sum_logprobs = sum_logprobs + current_logprobs * mask
         eot_mask = (tokens[:, -1] == self.eot)
-        next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
-        tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
+        next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
+        tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
         completed = bool(mx.all(tokens[:, -1] == self.eot))
-
+
         return tokens, completed

     def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
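The update step above does its end-of-text bookkeeping with masks rather than branches. A NumPy toy of the same idea (values invented; 50257 stands in for the EOT id):

import numpy as np

eot = 50257
tokens = np.array([[11, eot], [12, 13]])            # two partial sequences
current_logprobs = np.array([-0.5, -0.7])
mask = (tokens[:, -1] != eot).astype(np.float32)    # finished row contributes 0
sum_logprobs = current_logprobs * mask
next_tokens = np.where(tokens[:, -1] == eot, eot, np.array([14, 15]))
tokens = np.concatenate([tokens, next_tokens[:, None]], axis=1)
print(tokens.tolist())        # [[11, 50257, 50257], [12, 13, 15]]
print(sum_logprobs.tolist())  # [-0.0, -0.7]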
@@ -56,7 +56,7 @@ class MLXGreedyDecoder:

 class MLXBeamSearchDecoder:
     """Beam search decoder using MLX operations."""
-
+
     def __init__(
         self,
         beam_size: int,
@@ -100,21 +100,21 @@ class MLXBeamSearchDecoder:
         if self.finished_sequences is None:
             self.finished_sequences = [{} for _ in range(n_audio)]
         logprobs = mx.softmax(logits, axis=-1)
-        logprobs = mx.log(logprobs + 1e-10)
+        logprobs = mx.log(logprobs + 1e-10)
         logprobs_np = np.array(logprobs)
         tokens_np = np.array(tokens)
         sum_logprobs_np = np.array(sum_logprobs)
-
+
         next_tokens, source_indices, finished_sequences = [], [], []
         new_sum_logprobs = []
-
+
         for i in range(n_audio):
             scores, sources, finished = {}, {}, {}
             for j in range(self.beam_size):
                 idx = i * self.beam_size + j
-                prefix = tokens_np[idx].tolist()
+                prefix = tokens_np[idx].tolist()
                 top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
-
+
                 for token_idx in top_k_indices:
                     logprob = logprobs_np[idx, token_idx]
                     new_logprob = sum_logprobs_np[idx] + logprob
@@ -136,7 +136,7 @@ class MLXBeamSearchDecoder:

             finished_sequences.append(finished)
         tokens = mx.array(np.array(next_tokens, dtype=np.int32))
-        sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
+        sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
         self.inference.rearrange_kv_cache(source_indices)
         assert len(self.finished_sequences) == len(finished_sequences)
         for previously_finished, newly_finished in zip(
@@ -150,14 +150,14 @@ class MLXBeamSearchDecoder:
             len(sequences) >= self.max_candidates
             for sequences in self.finished_sequences
         )
-
+
         return tokens, completed

     def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
         """Finalize beam search by selecting best sequences."""
         preceding_tokens_np = np.array(preceding_tokens)
         sum_logprobs_np = np.array(sum_logprobs)
-
+
         n_audio = preceding_tokens_np.shape[0] // self.beam_size
         tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
         sum_logprobs_list: List[float] = [0.0] * n_audio
@@ -181,34 +181,34 @@ class MLXBeamSearchDecoder:

 class MLXInference:
     """MLX inference wrapper for beam search KV cache management."""
-
+
     def __init__(self, model, initial_token_length: int):
         self.model = model
         self.initial_token_length = initial_token_length
         self.kv_cache = None
-
+
     def rearrange_kv_cache(self, source_indices: List[int]):
         """Rearrange KV cache based on beam search source indices."""
         if self.kv_cache is None:
             return
-
+
         if source_indices == list(range(len(source_indices))):
             return
-
+
         source_indices_mx = mx.array(source_indices, dtype=mx.int32)
-
+
         new_cache = []
         for layer_cache in self.kv_cache:
-            (k, v), (cross_k, cross_v) = layer_cache
+            (k, v), (cross_k, cross_v) = layer_cache
             new_k = k[source_indices_mx]
             new_v = v[source_indices_mx]
             new_cache.append(((new_k, new_v), (cross_k, cross_v)))
-
+
         self.kv_cache = new_cache
-
+
     def logits(
-        self,
-        tokens: mx.array,
+        self,
+        tokens: mx.array,
         audio_features: mx.array,
     ) -> Tuple[mx.array, List]:
         """Get logits from decoder with KV cache."""
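rearrange_kv_cache above implements the usual beam-search gather: after a step, slot j continues the hypothesis that previously lived at source_indices[j], so every cached tensor is reindexed along the beam axis. A NumPy toy of the same gather:

import numpy as np

k = np.array([[0.0], [1.0], [2.0]])   # cached keys for beam_size = 3
source_indices = [0, 0, 2]            # slots 0 and 1 both extend old beam 0
print(k[source_indices].ravel().tolist())  # [0.0, 0.0, 2.0]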
@@ -4,7 +4,6 @@ from typing import Any, List, Tuple

 import mlx.core as mx
 import numpy as np
-
 from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
 from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim

@@ -15,7 +14,6 @@ from ..config import AlignAttConfig
 from .decoder_state import MLXDecoderState
 from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference


 logger = logging.getLogger(__name__)
-
-
+
@@ -41,17 +41,17 @@ def load_mlx_encoder(
         nn.quantize(model, **quantization, class_predicate=class_predicate)

     weights = tree_unflatten(list(weights.items()))
-
+
     # we only want to load the encoder weights here.
-    # Size examples: for tiny.en,
+    # Size examples: for tiny.en,
     # Decoder weights: 59110771 bytes
     # Encoder weights: 15268874 bytes



     encoder_weights = {}
     encoder_weights['encoder'] = weights['encoder']
     del(weights)




     model.update(encoder_weights)
@@ -89,7 +89,7 @@ def load_mlx_model(
         nn.quantize(model, **quantization, class_predicate=class_predicate)

     weights = tree_unflatten(list(weights.items()))

     model.update(weights)
     mx.eval(model.parameters())
-    return model
+    return model
@@ -6,13 +6,9 @@ import numpy as np
 import torch
 import torch.nn.functional as F

-from whisperlivekit.backend_support import (faster_backend_available,
-                                            mlx_backend_available)
-from whisperlivekit.whisper.audio import (N_FRAMES, N_SAMPLES,
-                                          TOKENS_PER_SECOND,
-                                          log_mel_spectrogram, pad_or_trim)
-from whisperlivekit.whisper.decoding import (BeamSearchDecoder, GreedyDecoder,
-                                             SuppressTokens)
+from whisperlivekit.backend_support import faster_backend_available, mlx_backend_available
+from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND, log_mel_spectrogram, pad_or_trim
+from whisperlivekit.whisper.decoding import BeamSearchDecoder, GreedyDecoder, SuppressTokens
 from whisperlivekit.whisper.timing import median_filter

 from .align_att_base import DEC_PAD, AlignAttBase
@@ -25,8 +21,7 @@ from .token_buffer import TokenBuffer
 logger = logging.getLogger(__name__)

 if mlx_backend_available():
-    from mlx_whisper.audio import \
-        log_mel_spectrogram as mlx_log_mel_spectrogram
+    from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
     from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim

 if faster_backend_available():
@@ -1,4 +1,3 @@
-import sys

 import torch

@@ -17,7 +16,7 @@ class TokenBuffer:
         if tokenizer is None:
             tokenizer = self.tokenizer
         if tokenizer is None:
-            raise ValueError("Tokenizer is not set.")
+            raise ValueError("Tokenizer is not set.")
         return self.prefix_token_ids + tokenizer.encode(self.text)

     def as_tensor(self, device=None):
@@ -26,7 +25,7 @@ class TokenBuffer:
         if device is None:
             raise ValueError("Device is not set.")
         tok_ids = self.as_token_ids()
-        return torch.tensor(tok_ids,
+        return torch.tensor(tok_ids,
                             dtype=torch.long, device=device).unsqueeze(0)

     def as_tensor_beam(self, beam, device=None):
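As a quick illustration of the encode path above (as_token_ids prepends prefix_token_ids to tokenizer.encode(text)), here is a self-contained toy with a fake tokenizer; the prefix id is invented:

class FakeTokenizer:
    def encode(self, text: str):
        return [ord(c) for c in text]   # stand-in for a real BPE encode

prefix_token_ids = [50258]              # hypothetical SOT-style prefix
print(prefix_token_ids + FakeTokenizer().encode("hi"))  # [50258, 104, 105]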
@@ -44,7 +43,7 @@ class TokenBuffer:
     @staticmethod
     def from_text(text, *a, **kw):
         return TokenBuffer(*a, text=text, **kw)
-
+
     def is_empty(self):
         return self.text is None or self.text == ""

@@ -11,10 +11,8 @@ import torch
 from torch import Tensor
 from tqdm import tqdm

-from whisperlivekit.whisper.audio import (load_audio, log_mel_spectrogram,
-                                          pad_or_trim)
-from whisperlivekit.whisper.decoding import (DecodingOptions, DecodingResult,
-                                             decode, detect_language)
+from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
+from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
 from whisperlivekit.whisper.model import ModelDimensions, Whisper
 from whisperlivekit.whisper.transcribe import transcribe
 from whisperlivekit.whisper.version import __version__
@@ -266,7 +264,7 @@ def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     for key, value in state_dict.items():
         if key == "alignment_heads":
             continue
-
+
         new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
         converted[new_key] = value

@@ -310,13 +308,13 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
     """
     if not lora_path:
         return None
-
+
     # Check if it's already a valid local path
     if os.path.isdir(lora_path):
         config_path = os.path.join(lora_path, "adapter_config.json")
         if os.path.isfile(config_path):
             return lora_path
-
+
     # Try to download from HuggingFace Hub
     if "/" in lora_path:
         try:
@@ -330,7 +328,7 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
             raise FileNotFoundError(
                 f"Could not find LoRA adapter at local path or HuggingFace Hub: {lora_path}. Error: {e}"
             )
-
+
     raise FileNotFoundError(
         f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID."
     )
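The resolution order above is: a valid local adapter directory wins (it must contain adapter_config.json), then anything with a "/" is treated as a HuggingFace repo id, and everything else raises. A standalone mirror of that order, with the actual Hub download stubbed out and a made-up repo id:

import os

def resolve_lora_path_sketch(lora_path: str) -> str:
    """Mirrors _resolve_lora_path's branching; the Hub call is stubbed out."""
    config = os.path.join(lora_path, "adapter_config.json")
    if os.path.isdir(lora_path) and os.path.isfile(config):
        return lora_path
    if "/" in lora_path:
        return f"<would download {lora_path} from the HuggingFace Hub>"
    raise FileNotFoundError(f"LoRA path '{lora_path}' is not a valid local directory or HuggingFace repo ID.")

print(resolve_lora_path_sketch("someuser/whisper-lora"))  # placeholder repo id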
@@ -339,7 +337,7 @@ def _resolve_lora_path(lora_path: Optional[str]) -> Optional[str]:
 def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
     if not lora_path:
         return
-
+
     # Resolve path (handles HuggingFace Hub download)
     lora_path = _resolve_lora_path(lora_path)
     if not lora_path:
@@ -410,10 +408,10 @@ def _load_checkpoint(
     if checkpoint_bytes is not None:
         with io.BytesIO(checkpoint_bytes) as fp:
             return torch.load(fp, map_location=device)
-
+
     file_path = Path(file_path)
     suffix = file_path.suffix.lower()
-
+
     if suffix == '.safetensors':
         try:
             from safetensors.torch import load_file
@@ -444,7 +442,7 @@ def _load_sharded_checkpoint(
     """
     merged_state_dict = {}
     first_suffix = shard_files[0].suffix.lower()
-
+
     if first_suffix == '.safetensors':
         try:
             from safetensors.torch import load_file
@@ -461,7 +459,7 @@ def _load_sharded_checkpoint(
                 shard_dict = torch.load(fp, map_location=device)
             if isinstance(shard_dict, dict):
                 merged_state_dict.update(shard_dict)
-
+
     return merged_state_dict

@@ -505,10 +503,10 @@ def load_model(
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
-
+
     checkpoint = None
     model_path_for_config = name # Used to find config.json for dims inference
-
+
     if name in _MODELS:
         checkpoint_file = _download(_MODELS[name], download_root, in_memory)
         if in_memory:
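The cache-root lines above follow the XDG convention; isolated, they behave like this (paths are examples):

import os

default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
print(download_root)  # e.g. /home/user/.cache/whisper unless XDG_CACHE_HOME is set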
@@ -525,13 +523,13 @@ def load_model(
         model_path_for_config = name
     elif os.path.isdir(name):
         model_info = detect_model_format(name)
-
+
         if not model_info.has_pytorch:
             raise RuntimeError(
                 f"No PyTorch checkpoint found in directory {name}. "
                 f"Expected .pt, .bin, or .safetensors file(s)."
             )
-
+
         if model_info.is_sharded:
             checkpoint = _load_sharded_checkpoint(model_info.pytorch_files, device)
         else:
@@ -547,7 +545,7 @@ def load_model(
         raise RuntimeError(
             f"Model {name} not found; available models = {available_models()}"
         )
-
+
     alignment_heads = _ALIGNMENT_HEADS.get(name, None)
     if custom_alignment_heads:
         alignment_heads = custom_alignment_heads.encode()
@@ -557,10 +555,10 @@ def load_model(
         state_dict = checkpoint["model_state_dict"]
     else:
         state_dict = checkpoint
-
+
     if alignment_heads is None and "alignment_heads" in state_dict:
         alignment_heads = state_dict["alignment_heads"]
-
+
     state_dict = _convert_hf_state_dict(state_dict)
     state_dict = _convert_mlx_state_dict(state_dict)
     _apply_lora_adapter(state_dict, lora_path)
@@ -578,10 +576,10 @@ def load_model(
         state_dict = checkpoint

     model = Whisper(dims, decoder_only=decoder_only)
-
+
     if decoder_only:
         state_dict = {
-            k: v for k, v in state_dict.items()
+            k: v for k, v in state_dict.items()
             if 'encoder' not in k
         }

@@ -604,7 +602,7 @@ def convert_encoder_to_coreml(
     dummy_frames = 3000, #Number of time frames to use for the dummy mel input during tracing
     precision = "float16",
 ):
-
+
     import coremltools as ct
     model = load_model(model_name, device="cpu", decoder_only=False)
     encoder = model.encoder.eval().cpu()
@@ -639,4 +637,4 @@ def convert_encoder_to_coreml(
     return output_path

 # if __name__ == "__main__":
-#     convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")
+#     convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")
@@ -1,6 +1,5 @@
 from dataclasses import dataclass, field, replace
-from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
-                    Tuple, Union)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch

@@ -175,7 +175,7 @@ class MultiHeadAttention(nn.Module):

 class ResidualAttentionBlock(nn.Module):
     def __init__(
-        self, n_state: int, n_head: int, cross_attention: bool = False,
+        self, n_state: int, n_head: int, cross_attention: bool = False,
         cache_id: str = "", n_text_ctx: int = 448
     ):
         super().__init__()
@@ -267,7 +267,7 @@ class TextDecoder(nn.Module):
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
             [
                 ResidualAttentionBlock(
-                    n_state, n_head, cross_attention=True,
+                    n_state, n_head, cross_attention=True,
                     cache_id=f"dec_layer{i}", n_text_ctx=n_ctx
                 )
                 for i in range(n_layer)
@@ -279,9 +279,9 @@ class TextDecoder(nn.Module):
         self.register_buffer("mask", mask, persistent=False)

     def forward(
-        self,
-        x: Tensor,
-        xa: Tensor,
+        self,
+        x: Tensor,
+        xa: Tensor,
         kv_cache: Optional[dict] = None,
         return_cross_attn: bool = False,
     ):
@@ -309,7 +309,7 @@ class TextDecoder(nn.Module):
         first_self_attn_key = self.blocks[0].attn.key_cache_id
         if first_self_attn_key in kv_cache:
             offset = kv_cache[first_self_attn_key].shape[1]
-
+
         x = (
             self.token_embedding(x)
             + self.positional_embedding[offset : offset + x.shape[-1]]
@@ -336,7 +336,7 @@ class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
         super().__init__()
         self.dims = dims
-
+
         if not decoder_only:
             self.encoder = AudioEncoder(
                 self.dims.n_mels,
@@ -373,15 +373,15 @@ class Whisper(nn.Module):
         return self.encoder(mel)

     def logits(
-        self,
-        tokens: torch.Tensor,
+        self,
+        tokens: torch.Tensor,
         audio_features: torch.Tensor,
         kv_cache: Optional[dict] = None,
         return_cross_attn: bool = False,
     ):
         return self.decoder(
-            tokens, audio_features,
-            kv_cache=kv_cache,
+            tokens, audio_features,
+            kv_cache=kv_cache,
             return_cross_attn=return_cross_attn
         )

@@ -8,13 +8,11 @@ import numpy as np
 import torch
 import tqdm

-from .audio import (FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES,
-                    SAMPLE_RATE, log_mel_spectrogram, pad_or_trim)
+from .audio import FRAMES_PER_SECOND, HOP_LENGTH, N_FRAMES, N_SAMPLES, SAMPLE_RATE, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import (exact_div, format_timestamp, get_end, get_writer,
-                    make_safe, optional_float, optional_int, str2bool)
+from .utils import exact_div, format_timestamp, get_end, get_writer, make_safe, optional_float, optional_int, str2bool

 if TYPE_CHECKING:
     from .model import Whisper

@@ -6,9 +6,10 @@ Everything else is just efficiency.
 @karpathy
 """

-import os # os.path.exists
-import math # math.log, math.exp
-import random # random.seed, random.choices, random.gauss, random.shuffle
+import math # math.log, math.exp
+import os # os.path.exists
+import random # random.seed, random.choices, random.gauss, random.shuffle

 random.seed(42) # Let there be order among chaos

 # Let there be an input dataset `docs`: list[str] of documents (e.g. a dataset of names)
@@ -197,4 +198,4 @@ for sample_idx in range(20):
         if token_id == BOS:
             break
         sample.append(uchars[token_id])
-    print(f"sample {sample_idx+1:2d}: {''.join(sample)}")
+    print(f"sample {sample_idx+1:2d}: {''.join(sample)}")