custom faster-whisper/mlx whisper encoder available

This commit is contained in:
Quentin Fuxa
2025-10-23 20:33:17 +02:00
parent 0af379c465
commit 714fb3b14a
3 changed files with 60 additions and 46 deletions

View File

@@ -10,6 +10,7 @@ from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND
import os
import gc
from pathlib import Path
logger = logging.getLogger(__name__)
import torch
@@ -22,9 +23,7 @@ try:
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print(f"""{"="*50}
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
{"="*50}""")
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper\n{"="*50}""")
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
@@ -35,8 +34,24 @@ else:
except ImportError:
HAS_FASTER_WHISPER = False
def model_path_and_type(model_path):
path = Path(model_path)
compatible_whisper_mlx = False
compatible_faster_whisper = False
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
if path.is_dir():
for file in path.iterdir():
if file.is_file():
if file.name in ['weights.npz', "weights.safetensors"]:
compatible_whisper_mlx = True
elif file.suffix.lower() == '.bin':
compatible_faster_whisper = True
elif file.suffix.lower() == '.pt':
pt_path = file
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
# TOO_MANY_REPETITIONS = 3
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
@@ -154,8 +169,11 @@ class SimulStreamingASR():
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
self.fast_encoder = False
if self.model_dir is not None:
self.model_path = self.model_dir
pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
if self.model_path:
pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
elif self.model_size is not None:
model_mapping = {
'tiny': './tiny.pt',
@@ -171,10 +189,12 @@ class SimulStreamingASR():
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_path = model_mapping.get(self.model_size, f'./{self.model_size}.pt')
pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
self.model_name = pt_path.name.replace(".pt", "")
self.cfg = AlignAttConfig(
model_path=self.model_path,
tokenizer_is_multilingual= not self.model_name.endswith(".en"),
segment_length=self.min_chunk_size,
frame_threshold=self.frame_threshold,
language=self.lan,
@@ -196,24 +216,27 @@ class SimulStreamingASR():
else:
self.tokenizer = None
if self.model_dir:
self.model_name = self.model_dir
self.model_path = None
else:
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.mlx_encoder, self.fw_encoder = None, None
if not self.disable_fast_encoder:
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
if self.model_path and compatible_whisper_mlx:
mlx_model = self.model_path
else:
mlx_model = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
self.fast_encoder = True
elif HAS_FASTER_WHISPER:
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
print('Simulstreaming will use Faster Whisper for the encoder.')
if self.model_path and compatible_faster_whisper:
fw_model = self.model_path
else:
fw_model = self.model_name
self.fw_encoder = WhisperModel(
self.model_name,
fw_model,
device='auto',
compute_type='auto',
)
@@ -224,7 +247,7 @@ class SimulStreamingASR():
def load_model(self):
whisper_model = load_model(
name=self.model_name,
name=self.model_path if self.model_path else self.model_name,
download_root=self.model_path,
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads

View File

@@ -4,26 +4,22 @@ from dataclasses import dataclass, field
from typing import Literal
@dataclass
class SimulWhisperConfig:
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
model_path: str
language: str = field(default="zh")
nonspeech_prob: float = 0.5
audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5
task: Literal["transcribe","translate"] = "transcribe"
init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None)
@dataclass
class AlignAttConfig(SimulWhisperConfig):
'''Options specific to the AlignAtt policy.'''
class AlignAttConfig():
eval_data_path: str = "tmp"
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
frame_threshold: int = 4
rewind_threshold: int = 200
audio_max_len: float = 20.0
cif_ckpt_path: str = ""
never_fire: bool = False
never_fire: bool = False
language: str = field(default="zh")
nonspeech_prob: float = 0.5
audio_min_len: float = 1.0
decoder_type: Literal["greedy","beam"] = "greedy"
beam_size: int = 5
task: Literal["transcribe","translate"] = "transcribe"
tokenizer_is_multilingual: bool = False
init_prompt: str = field(default=None)
static_init_prompt: str = field(default=None)
max_context_tokens: int = field(default=None)

View File

@@ -51,20 +51,15 @@ class PaddedAlignAttWhisper:
fw_encoder=None,
) -> None:
self.log_segments = 0
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
if loaded_model:
self.model = loaded_model
else:
self.model = load_model(name=model_name, download_root=model_path)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = loaded_model
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Model dimensions: {self.model.dims}")
self.speaker = -1
self.decode_options = DecodingOptions(
@@ -72,7 +67,7 @@ class PaddedAlignAttWhisper:
without_timestamps = True,
task=cfg.task
)
self.tokenizer_is_multilingual = not model_name.endswith(".en")
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
# self.create_tokenizer('en')
self.detected_language = cfg.language if cfg.language != "auto" else None