mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
custom faster-whisper/mlx whisper encoder available
This commit is contained in:
@@ -10,6 +10,7 @@ from .whisper import load_model, tokenizer
|
||||
from .whisper.audio import TOKENS_PER_SECOND
|
||||
import os
|
||||
import gc
|
||||
from pathlib import Path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import torch
|
||||
@@ -22,9 +23,7 @@ try:
|
||||
HAS_MLX_WHISPER = True
|
||||
except ImportError:
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
print(f"""{"="*50}
|
||||
MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
|
||||
{"="*50}""")
|
||||
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper\n{"="*50}""")
|
||||
HAS_MLX_WHISPER = False
|
||||
if HAS_MLX_WHISPER:
|
||||
HAS_FASTER_WHISPER = False
|
||||
@@ -35,8 +34,24 @@ else:
|
||||
except ImportError:
|
||||
HAS_FASTER_WHISPER = False
|
||||
|
||||
def model_path_and_type(model_path):
|
||||
path = Path(model_path)
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
compatible_faster_whisper = False
|
||||
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
|
||||
|
||||
if path.is_dir():
|
||||
for file in path.iterdir():
|
||||
if file.is_file():
|
||||
if file.name in ['weights.npz', "weights.safetensors"]:
|
||||
compatible_whisper_mlx = True
|
||||
elif file.suffix.lower() == '.bin':
|
||||
compatible_faster_whisper = True
|
||||
elif file.suffix.lower() == '.pt':
|
||||
pt_path = file
|
||||
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
# TOO_MANY_REPETITIONS = 3
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
@@ -154,8 +169,11 @@ class SimulStreamingASR():
|
||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||
|
||||
self.fast_encoder = False
|
||||
if self.model_dir is not None:
|
||||
self.model_path = self.model_dir
|
||||
|
||||
pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
if self.model_path:
|
||||
pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
|
||||
|
||||
elif self.model_size is not None:
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
@@ -171,10 +189,12 @@ class SimulStreamingASR():
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_path = model_mapping.get(self.model_size, f'./{self.model_size}.pt')
|
||||
pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
|
||||
|
||||
self.model_name = pt_path.name.replace(".pt", "")
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
model_path=self.model_path,
|
||||
tokenizer_is_multilingual= not self.model_name.endswith(".en"),
|
||||
segment_length=self.min_chunk_size,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.lan,
|
||||
@@ -196,24 +216,27 @@ class SimulStreamingASR():
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
if self.model_dir:
|
||||
self.model_name = self.model_dir
|
||||
self.model_path = None
|
||||
else:
|
||||
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
|
||||
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
|
||||
|
||||
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if not self.disable_fast_encoder:
|
||||
if HAS_MLX_WHISPER:
|
||||
print('Simulstreaming will use MLX whisper for a faster encoder.')
|
||||
mlx_model_name = mlx_model_mapping[self.model_name]
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
|
||||
if self.model_path and compatible_whisper_mlx:
|
||||
mlx_model = self.model_path
|
||||
else:
|
||||
mlx_model = mlx_model_mapping[self.model_name]
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
self.fast_encoder = True
|
||||
elif HAS_FASTER_WHISPER:
|
||||
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
if self.model_path and compatible_faster_whisper:
|
||||
fw_model = self.model_path
|
||||
else:
|
||||
fw_model = self.model_name
|
||||
self.fw_encoder = WhisperModel(
|
||||
self.model_name,
|
||||
fw_model,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
@@ -224,7 +247,7 @@ class SimulStreamingASR():
|
||||
|
||||
def load_model(self):
|
||||
whisper_model = load_model(
|
||||
name=self.model_name,
|
||||
name=self.model_path if self.model_path else self.model_name,
|
||||
download_root=self.model_path,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
|
||||
@@ -4,26 +4,22 @@ from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
@dataclass
|
||||
class SimulWhisperConfig:
|
||||
'''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
|
||||
model_path: str
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
@dataclass
|
||||
class AlignAttConfig(SimulWhisperConfig):
|
||||
'''Options specific to the AlignAtt policy.'''
|
||||
class AlignAttConfig():
|
||||
eval_data_path: str = "tmp"
|
||||
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||
frame_threshold: int = 4
|
||||
rewind_threshold: int = 200
|
||||
audio_max_len: float = 20.0
|
||||
cif_ckpt_path: str = ""
|
||||
never_fire: bool = False
|
||||
never_fire: bool = False
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
tokenizer_is_multilingual: bool = False
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
@@ -51,20 +51,15 @@ class PaddedAlignAttWhisper:
|
||||
fw_encoder=None,
|
||||
) -> None:
|
||||
self.log_segments = 0
|
||||
model_name = os.path.basename(cfg.model_path).replace(".pt", "")
|
||||
model_path = os.path.dirname(os.path.abspath(cfg.model_path))
|
||||
if loaded_model:
|
||||
self.model = loaded_model
|
||||
else:
|
||||
self.model = load_model(name=model_name, download_root=model_path)
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
self.model = loaded_model
|
||||
self.mlx_encoder = mlx_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
if fw_encoder:
|
||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
self.speaker = -1
|
||||
self.decode_options = DecodingOptions(
|
||||
@@ -72,7 +67,7 @@ class PaddedAlignAttWhisper:
|
||||
without_timestamps = True,
|
||||
task=cfg.task
|
||||
)
|
||||
self.tokenizer_is_multilingual = not model_name.endswith(".en")
|
||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
# self.create_tokenizer('en')
|
||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
|
||||
Reference in New Issue
Block a user