diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index cdb4303..ee248ef 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -10,6 +10,7 @@ from .whisper import load_model, tokenizer
 from .whisper.audio import TOKENS_PER_SECOND
 import os
 import gc
+from pathlib import Path
 logger = logging.getLogger(__name__)
 
 import torch
@@ -22,9 +23,7 @@ try:
     HAS_MLX_WHISPER = True
 except ImportError:
     if platform.system() == "Darwin" and platform.machine() == "arm64":
-        print(f"""{"="*50}
-MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper
-{"="*50}""")
+        print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper\n{"="*50}""")
     HAS_MLX_WHISPER = False
 if HAS_MLX_WHISPER:
     HAS_FASTER_WHISPER = False
@@ -35,8 +34,24 @@ else:
     except ImportError:
         HAS_FASTER_WHISPER = False
 
+def model_path_and_type(model_path):
+    path = Path(model_path)
+
+    compatible_whisper_mlx = False
+    compatible_faster_whisper = False
+    pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
+
+    if path.is_dir():
+        for file in path.iterdir():
+            if file.is_file():
+                if file.name in ['weights.npz', "weights.safetensors"]:
+                    compatible_whisper_mlx = True
+                elif file.suffix.lower() == '.bin':
+                    compatible_faster_whisper = True
+                elif file.suffix.lower() == '.pt':
+                    pt_path = file
+    return pt_path, compatible_whisper_mlx, compatible_faster_whisper
 
-# TOO_MANY_REPETITIONS = 3
 
 class SimulStreamingOnlineProcessor:
     SAMPLING_RATE = 16000
@@ -154,8 +169,11 @@ class SimulStreamingASR():
         self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
         self.fast_encoder = False
 
-        if self.model_dir is not None:
-            self.model_path = self.model_dir
+
+        pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
+        if self.model_path:
+            pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
+
         elif self.model_size is not None:
             model_mapping = {
                 'tiny': './tiny.pt',
@@ -171,10 +189,12 @@ class SimulStreamingASR():
                 'large-v3': './large-v3.pt',
                 'large': './large-v3.pt'
             }
-            self.model_path = model_mapping.get(self.model_size, f'./{self.model_size}.pt')
+            pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
+
+        self.model_name = pt_path.name.replace(".pt", "")
 
         self.cfg = AlignAttConfig(
-            model_path=self.model_path,
+            tokenizer_is_multilingual= not self.model_name.endswith(".en"),
             segment_length=self.min_chunk_size,
             frame_threshold=self.frame_threshold,
             language=self.lan,
@@ -196,24 +216,27 @@
         else:
             self.tokenizer = None
 
 
-        if self.model_dir:
-            self.model_name = self.model_dir
-            self.model_path = None
-        else:
-            self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
-            self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
+
+        self.mlx_encoder, self.fw_encoder = None, None
 
         if not self.disable_fast_encoder:
             if HAS_MLX_WHISPER:
                 print('Simulstreaming will use MLX whisper for a faster encoder.')
-                mlx_model_name = mlx_model_mapping[self.model_name]
-                self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
+                if self.model_path and compatible_whisper_mlx:
+                    mlx_model = self.model_path
+                else:
+                    mlx_model = mlx_model_mapping[self.model_name]
+                self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
                 self.fast_encoder = True
-            elif HAS_FASTER_WHISPER:
+            elif HAS_FASTER_WHISPER and compatible_faster_whisper:
                 print('Simulstreaming will use Faster Whisper for the encoder.')
+                if self.model_path and compatible_faster_whisper:
+                    fw_model = self.model_path
+                else:
+                    fw_model = self.model_name
                 self.fw_encoder = WhisperModel(
-                    self.model_name,
+                    fw_model,
                     device='auto',
                     compute_type='auto',
                 )
@@ -224,7 +247,7 @@
 
     def load_model(self):
         whisper_model = load_model(
-            name=self.model_name,
+            name=self.model_path if self.model_path else self.model_name,
            download_root=self.model_path,
            decoder_only=self.fast_encoder,
            custom_alignment_heads=self.custom_alignment_heads
diff --git a/whisperlivekit/simul_whisper/config.py b/whisperlivekit/simul_whisper/config.py
index bfc2f31..08f72c1 100644
--- a/whisperlivekit/simul_whisper/config.py
+++ b/whisperlivekit/simul_whisper/config.py
@@ -4,26 +4,22 @@ from dataclasses import dataclass, field
 from typing import Literal
 
 @dataclass
-class SimulWhisperConfig:
-    '''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
-    model_path: str
-    language: str = field(default="zh")
-    nonspeech_prob: float = 0.5
-    audio_min_len: float = 1.0
-    decoder_type: Literal["greedy","beam"] = "greedy"
-    beam_size: int = 5
-    task: Literal["transcribe","translate"] = "transcribe"
-    init_prompt: str = field(default=None)
-    static_init_prompt: str = field(default=None)
-    max_context_tokens: int = field(default=None)
-
-@dataclass
-class AlignAttConfig(SimulWhisperConfig):
-    '''Options specific to the AlignAtt policy.'''
+class AlignAttConfig():
     eval_data_path: str = "tmp"
     segment_length: float = field(default=1.0, metadata = {"help": "in second"})
     frame_threshold: int = 4
     rewind_threshold: int = 200
     audio_max_len: float = 20.0
     cif_ckpt_path: str = ""
-    never_fire: bool = False
\ No newline at end of file
+    never_fire: bool = False
+    language: str = field(default="zh")
+    nonspeech_prob: float = 0.5
+    audio_min_len: float = 1.0
+    decoder_type: Literal["greedy","beam"] = "greedy"
+    beam_size: int = 5
+    task: Literal["transcribe","translate"] = "transcribe"
+    tokenizer_is_multilingual: bool = False
+    init_prompt: str = field(default=None)
+    static_init_prompt: str = field(default=None)
+    max_context_tokens: int = field(default=None)
+    
\ No newline at end of file
diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py
index 250fa5b..73a9c42 100644
--- a/whisperlivekit/simul_whisper/simul_whisper.py
+++ b/whisperlivekit/simul_whisper/simul_whisper.py
@@ -51,20 +51,15 @@ class PaddedAlignAttWhisper:
         fw_encoder=None,
     ) -> None:
         self.log_segments = 0
-        model_name = os.path.basename(cfg.model_path).replace(".pt", "")
-        model_path = os.path.dirname(os.path.abspath(cfg.model_path))
-        if loaded_model:
-            self.model = loaded_model
-        else:
-            self.model = load_model(name=model_name, download_root=model_path)
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
+        self.model = loaded_model
         self.mlx_encoder = mlx_encoder
         self.fw_encoder = fw_encoder
         if fw_encoder:
             self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
 
-
+
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
         logger.info(f"Model dimensions: {self.model.dims}")
         self.speaker = -1
         self.decode_options = DecodingOptions(
@@ -72,7 +67,7 @@
             without_timestamps = True,
             task=cfg.task
         )
-        self.tokenizer_is_multilingual = not model_name.endswith(".en")
+        self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
         self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
         # self.create_tokenizer('en')
         self.detected_language = cfg.language if cfg.language != "auto" else None
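
For context, a minimal, self-contained sketch of how the new model_path_and_type helper classifies a local model path. The function body is copied from the patched backend.py; the directory and file names in the demo are hypothetical, and only the markers the helper checks for (weights.npz / weights.safetensors, *.bin, *.pt) matter:

from pathlib import Path
import tempfile

def model_path_and_type(model_path):
    # Copied verbatim from the patched backend.py so the demo is self-contained.
    path = Path(model_path)

    compatible_whisper_mlx = False
    compatible_faster_whisper = False
    pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None

    if path.is_dir():
        for file in path.iterdir():
            if file.is_file():
                if file.name in ['weights.npz', "weights.safetensors"]:
                    compatible_whisper_mlx = True
                elif file.suffix.lower() == '.bin':
                    compatible_faster_whisper = True
                elif file.suffix.lower() == '.pt':
                    pt_path = file
    return pt_path, compatible_whisper_mlx, compatible_faster_whisper

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical MLX export: recognized by its weights.safetensors file.
    mlx_dir = Path(tmp) / "mlx-model"
    mlx_dir.mkdir()
    (mlx_dir / "weights.safetensors").touch()
    print(model_path_and_type(mlx_dir))   # (None, True, False)

    # Hypothetical faster-whisper (CTranslate2) export: recognized by a *.bin file.
    fw_dir = Path(tmp) / "faster-whisper-model"
    fw_dir.mkdir()
    (fw_dir / "model.bin").touch()
    print(model_path_and_type(fw_dir))    # (None, False, True)

    # A plain .pt checkpoint is returned directly as pt_path.
    pt_file = Path(tmp) / "base.pt"
    pt_file.touch()
    print(model_path_and_type(pt_file))   # (.../base.pt, False, False)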
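Likewise, a sketch of constructing the reshaped AlignAttConfig: model_path is gone, and whether the tokenizer is multilingual is now passed in explicitly instead of being derived from a ".en" suffix on the model name. Field names are taken from the patched config.py; the import path is assumed from the file location:

from whisperlivekit.simul_whisper.config import AlignAttConfig

cfg = AlignAttConfig(
    tokenizer_is_multilingual=True,   # replaces the old model-name ".en" check
    segment_length=1.0,               # seconds of audio per decoding step
    frame_threshold=4,                # frame threshold used by the AlignAtt policy
    language="en",
    task="transcribe",
)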