From 8cbaeecc75cfb144a87bd3255bcf7e2a823c2f31 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sat, 27 Sep 2025 11:04:00 +0200 Subject: [PATCH] cutom alignment heads parameter for custom models --- README.md | 2 + whisperlivekit/audio_processor.py | 3 +- whisperlivekit/core.py | 158 ++++++++---------- whisperlivekit/parse_args.py | 9 + whisperlivekit/simul_whisper/backend.py | 57 +++---- .../simul_whisper/whisper/__init__.py | 11 +- whisperlivekit/translation/translation.py | 20 +-- .../whisper_streaming_custom/backends.py | 32 ++-- .../whisper_streaming_custom/online_asr.py | 10 +- .../whisper_online.py | 54 +++--- 10 files changed, 179 insertions(+), 177 deletions(-) diff --git a/README.md b/README.md index 1f15dcd..88a219c 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,7 @@ async def websocket_endpoint(websocket: WebSocket): | Parameter | Description | Default | |-----------|-------------|---------| | `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md) | `small` | +| `--model-dir` | Directory containing Whisper model.bin and other files. Overrides `--model`. | `None` | | `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` | | `--target-language` | If sets, activates translation using NLLB. Ex: `fr`. [118 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/translation/mapping_languages.py). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` | | `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` | @@ -169,6 +170,7 @@ async def websocket_endpoint(websocket: WebSocket): | SimulStreaming backend options | Description | Default | |-----------|-------------|---------| | `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` | +| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` | | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` | | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` | | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` | diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 20d19eb..97a1c76 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -72,7 +72,6 @@ class AudioProcessor: # Models and processing self.asr = models.asr - self.tokenizer = models.tokenizer self.vac_model = models.vac_model if self.args.vac: self.vac = FixedVADIterator(models.vac_model) @@ -109,7 +108,7 @@ class AudioProcessor: self.diarization = None if self.args.transcription: - self.transcription = online_factory(self.args, models.asr, models.tokenizer) + self.transcription = online_factory(self.args, models.asr) self.sep = self.transcription.asr.sep if self.args.diarization: self.diarization = online_diarization_factory(self.args, models.diarization_model) diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index b4ef8d8..e6e9893 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -4,10 +4,15 @@ try: except ImportError: from .whisper_streaming_custom.whisper_online import backend_factory from .whisper_streaming_custom.online_asr import OnlineASRProcessor -from whisperlivekit.warmup import warmup_asr from argparse import Namespace import sys +def update_with_kwargs(_dict, kwargs): + _dict.update({ + k: v for k, v in kwargs.items() if k in _dict + }) + return _dict + class TranscriptionEngine: _instance = None _initialized = False @@ -21,20 +26,12 @@ class TranscriptionEngine: if TranscriptionEngine._initialized: return - defaults = { + global_params = { "host": "localhost", "port": 8000, - "warmup_file": None, "diarization": False, "punctuation_split": False, - "min_chunk_size": 0.5, - "model": "tiny", - "model_cache_dir": None, - "model_dir": None, - "lan": "auto", - "task": "transcribe", "target_language": "", - "backend": "faster-whisper", "vac": True, "vac_chunk_size": 0.04, "log_level": "DEBUG", @@ -43,54 +40,31 @@ class TranscriptionEngine: "transcription": True, "vad": True, "pcm_input": False, - - # whisperstreaming params: - "buffer_trimming": "segment", - "confidence_validation": False, - "buffer_trimming_sec": 15, - - # simulstreaming params: - "disable_fast_encoder": False, - "frame_threshold": 25, - "beams": 1, - "decoder_type": None, - "audio_max_len": 20.0, - "audio_min_len": 0.0, - "cif_ckpt_path": None, - "never_fire": False, - "init_prompt": None, - "static_init_prompt": None, - "max_context_tokens": None, - "model_path": './base.pt', - "diarization_backend": "sortformer", - - # diarization params: "disable_punctuation_split" : False, - "segmentation_model": "pyannote/segmentation-3.0", - "embedding_model": "pyannote/embedding", - - # translation params: - "nllb_backend": "ctranslate2", - "nllb_size": "600M" + "diarization_backend": "sortformer", } + global_params = update_with_kwargs(global_params, kwargs) - config_dict = {**defaults, **kwargs} + transcription_common_params = { + "backend": "simulstreaming", + "warmup_file": None, + "min_chunk_size": 0.5, + "model_size": "tiny", + "model_cache_dir": None, + "model_dir": None, + "lan": "auto", + "task": "transcribe", + } + transcription_common_params = update_with_kwargs(transcription_common_params, kwargs) if 'no_transcription' in kwargs: - config_dict['transcription'] = not kwargs['no_transcription'] + global_params['transcription'] = not global_params['no_transcription'] if 'no_vad' in kwargs: - config_dict['vad'] = not kwargs['no_vad'] + global_params['vad'] = not kwargs['no_vad'] if 'no_vac' in kwargs: - config_dict['vac'] = not kwargs['no_vac'] - - config_dict.pop('no_transcription', None) - config_dict.pop('no_vad', None) + global_params['vac'] = not kwargs['no_vac'] - if 'language' in kwargs: - config_dict['lan'] = kwargs['language'] - config_dict.pop('language', None) - - self.args = Namespace(**config_dict) + self.args = Namespace(**{**global_params, **transcription_common_params}) self.asr = None self.tokenizer = None @@ -104,44 +78,57 @@ class TranscriptionEngine: if self.args.transcription: if self.args.backend == "simulstreaming": from whisperlivekit.simul_whisper import SimulStreamingASR - self.tokenizer = None - simulstreaming_kwargs = {} - for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len', - 'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt', - 'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']: - if hasattr(self.args, attr): - simulstreaming_kwargs[attr] = getattr(self.args, attr) - - # Add segment_length from min_chunk_size - simulstreaming_kwargs['segment_length'] = getattr(self.args, 'min_chunk_size', 0.5) - simulstreaming_kwargs['task'] = self.args.task - size = self.args.model + simulstreaming_params = { + "disable_fast_encoder": False, + "custom_alignment_heads": None, + "frame_threshold": 25, + "beams": 1, + "decoder_type": None, + "audio_max_len": 20.0, + "audio_min_len": 0.0, + "cif_ckpt_path": None, + "never_fire": False, + "init_prompt": None, + "static_init_prompt": None, + "max_context_tokens": None, + "model_path": './base.pt', + "preload_model_count": 1, + } + simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs) + + self.tokenizer = None self.asr = SimulStreamingASR( - modelsize=size, - lan=self.args.lan, - cache_dir=getattr(self.args, 'model_cache_dir', None), - model_dir=getattr(self.args, 'model_dir', None), - **simulstreaming_kwargs + **transcription_common_params, **simulstreaming_params ) - else: - self.asr, self.tokenizer = backend_factory(self.args) - warmup_asr(self.asr, self.args.warmup_file) #for simulstreaming, warmup should be done in the online class not here + + whisperstreaming_params = { + "buffer_trimming": "segment", + "confidence_validation": False, + "buffer_trimming_sec": 15, + } + whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs) + + self.asr = backend_factory( + **transcription_common_params, **whisperstreaming_params + ) if self.args.diarization: if self.args.diarization_backend == "diart": from whisperlivekit.diarization.diart_backend import DiartDiarization + diart_params = { + "segmentation_model": "pyannote/segmentation-3.0", + "embedding_model": "pyannote/embedding", + } + diart_params = update_with_kwargs(diart_params, kwargs) self.diarization_model = DiartDiarization( block_duration=self.args.min_chunk_size, - segmentation_model_name=self.args.segmentation_model, - embedding_model_name=self.args.embedding_model + **diart_params ) elif self.args.diarization_backend == "sortformer": from whisperlivekit.diarization.sortformer_backend import SortformerDiarization self.diarization_model = SortformerDiarization() - else: - raise ValueError(f"Unknown diarization backend: {self.args.diarization_backend}") self.translation_model = None if self.args.target_language: @@ -149,26 +136,21 @@ class TranscriptionEngine: raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming') else: from whisperlivekit.translation.translation import load_model - self.translation_model = load_model([self.args.lan], backend=self.args.nllb_backend, model_size=self.args.nllb_size) #in the future we want to handle different languages for different speakers + translation_params = { + "nllb_backend": "ctranslate2", + "nllb_size": "600M" + } + translation_params = update_with_kwargs(translation_params, kwargs) + self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers TranscriptionEngine._initialized = True - -def online_factory(args, asr, tokenizer, logfile=sys.stderr): +def online_factory(args, asr): if args.backend == "simulstreaming": from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor - online = SimulStreamingOnlineProcessor( - asr, - logfile=logfile, - ) + online = SimulStreamingOnlineProcessor(asr) else: - online = OnlineASRProcessor( - asr, - tokenizer, - logfile=logfile, - buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec), - confidence_validation = args.confidence_validation - ) + online = OnlineASRProcessor(asr) return online diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index 28e81bc..7af1564 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -89,6 +89,7 @@ def parse_args(): "--model", type=str, default="small", + dest='model_size', help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.", ) @@ -109,6 +110,7 @@ def parse_args(): "--language", type=str, default="auto", + dest='lan', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.", ) parser.add_argument( @@ -189,6 +191,13 @@ def parse_args(): dest="disable_fast_encoder", help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited", ) + + simulstreaming_group.add_argument( + "--custom-alignment-heads", + type=str, + default=None, + help="Use your own alignment heads, useful when `--model-dir` is used", + ) simulstreaming_group.add_argument( "--frame-threshold", diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index e816ed8..a52acc5 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -47,7 +47,6 @@ class SimulStreamingOnlineProcessor: self, asr, logfile=sys.stderr, - warmup_file=None ): self.asr = asr self.logfile = logfile @@ -146,31 +145,20 @@ class SimulStreamingASR(): """SimulStreaming backend with AlignAtt policy.""" sep = "" - def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs): + def __init__(self, logfile=sys.stderr, **kwargs): self.logfile = logfile self.transcribe_kargs = {} - self.original_language = lan - self.model_path = kwargs.get('model_path', './large-v3.pt') - self.frame_threshold = kwargs.get('frame_threshold', 25) - self.audio_max_len = kwargs.get('audio_max_len', 20.0) - self.audio_min_len = kwargs.get('audio_min_len', 0.0) - self.segment_length = kwargs.get('segment_length', 0.5) - self.beams = kwargs.get('beams', 1) - self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam') - self.task = kwargs.get('task', 'transcribe') - self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None) - self.never_fire = kwargs.get('never_fire', False) - self.init_prompt = kwargs.get('init_prompt', None) - self.static_init_prompt = kwargs.get('static_init_prompt', None) - self.max_context_tokens = kwargs.get('max_context_tokens', None) - self.warmup_file = kwargs.get('warmup_file', None) - self.preload_model_count = kwargs.get('preload_model_count', 1) - self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False) + for key, value in kwargs.items(): + setattr(self, key, value) + + if self.decoder_type is None: + self.decoder_type = 'greedy' if self.beams == 1 else 'beam' + self.fast_encoder = False - if model_dir is not None: - self.model_path = model_dir - elif modelsize is not None: + if self.model_dir is not None: + self.model_path = self.model_dir + elif self.model_size is not None: model_mapping = { 'tiny': './tiny.pt', 'base': './base.pt', @@ -185,13 +173,13 @@ class SimulStreamingASR(): 'large-v3': './large-v3.pt', 'large': './large-v3.pt' } - self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt') + self.model_path = model_mapping.get(self.model_size, f'./{self.model_size}.pt') self.cfg = AlignAttConfig( model_path=self.model_path, - segment_length=self.segment_length, + segment_length=self.min_chunk_size, frame_threshold=self.frame_threshold, - language=self.original_language, + language=self.lan, audio_max_len=self.audio_max_len, audio_min_len=self.audio_min_len, cif_ckpt_path=self.cif_ckpt_path, @@ -210,11 +198,15 @@ class SimulStreamingASR(): else: self.tokenizer = None - self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "") - self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path)) + if self.model_dir: + self.model_name = self.model_dir + self.model_path = None + else: + self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "") + self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path)) self.mlx_encoder, self.fw_encoder = None, None - if not self.disable_fast_encoder: + if not self.disable_fast_encoder and not self.model_dir: if HAS_MLX_WHISPER: print('Simulstreaming will use MLX whisper for a faster encoder.') mlx_model_name = mlx_model_mapping[self.model_name] @@ -233,7 +225,12 @@ class SimulStreamingASR(): def load_model(self): - whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder) + whisper_model = load_model( + name=self.model_name, + download_root=self.model_path, + decoder_only=self.fast_encoder, + custom_alignment_heads=self.custom_alignment_heads + ) warmup_audio = load_file(self.warmup_file) if warmup_audio is not None: warmup_audio = torch.from_numpy(warmup_audio).float() @@ -249,7 +246,7 @@ class SimulStreamingASR(): else: # For standard encoder, use the original transcribe warmup warmup_audio = load_file(self.warmup_file) - whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None) + whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None) return whisper_model def get_new_model_instance(self): diff --git a/whisperlivekit/simul_whisper/whisper/__init__.py b/whisperlivekit/simul_whisper/whisper/__init__.py index 069ddbb..5c6db94 100644 --- a/whisperlivekit/simul_whisper/whisper/__init__.py +++ b/whisperlivekit/simul_whisper/whisper/__init__.py @@ -105,7 +105,8 @@ def load_model( device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False, - decoder_only=False + decoder_only=False, + custom_alignment_heads=None ) -> Whisper: """ Load a Whisper ASR model @@ -135,15 +136,17 @@ def load_model( download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") if name in _MODELS: - checkpoint_file = _download(_MODELS[name], download_root, in_memory) - alignment_heads = _ALIGNMENT_HEADS[name] + checkpoint_file = _download(_MODELS[name], download_root, in_memory) elif os.path.isfile(name): checkpoint_file = open(name, "rb").read() if in_memory else name - alignment_heads = None else: raise RuntimeError( f"Model {name} not found; available models = {available_models()}" ) + + alignment_heads = _ALIGNMENT_HEADS.get(name, None) + if custom_alignment_heads: + alignment_heads = custom_alignment_heads.encode() with ( io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") diff --git a/whisperlivekit/translation/translation.py b/whisperlivekit/translation/translation.py index 90ce47e..3cdce4d 100644 --- a/whisperlivekit/translation/translation.py +++ b/whisperlivekit/translation/translation.py @@ -21,27 +21,27 @@ class TranslationModel(): device: str tokenizer: dict = field(default_factory=dict) backend_type: str = 'ctranslate2' - model_size: str = '600M' + nllb_size: str = '600M' def get_tokenizer(self, input_lang): if not self.tokenizer.get(input_lang, False): self.tokenizer[input_lang] = transformers.AutoTokenizer.from_pretrained( - f"facebook/nllb-200-distilled-{self.model_size}", + f"facebook/nllb-200-distilled-{self.nllb_size}", src_lang=input_lang, clean_up_tokenization_spaces=True ) return self.tokenizer[input_lang] -def load_model(src_langs, backend='ctranslate2', model_size='600M'): +def load_model(src_langs, nllb_backend='ctranslate2', nllb_size='600M'): device = "cuda" if torch.cuda.is_available() else "cpu" - MODEL = f'nllb-200-distilled-{model_size}-ctranslate2' - if backend=='ctranslate2': + MODEL = f'nllb-200-distilled-{nllb_size}-ctranslate2' + if nllb_backend=='ctranslate2': MODEL_GUY = 'entai2965' huggingface_hub.snapshot_download(MODEL_GUY + '/' + MODEL,local_dir=MODEL) translator = ctranslate2.Translator(MODEL,device=device) - elif backend=='transformers': - translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{model_size}") + elif nllb_backend=='transformers': + translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(f"facebook/nllb-200-distilled-{nllb_size}") tokenizer = dict() for src_lang in src_langs: if src_lang != 'auto': @@ -50,9 +50,9 @@ def load_model(src_langs, backend='ctranslate2', model_size='600M'): translation_model = TranslationModel( translator=translator, tokenizer=tokenizer, - backend_type=backend, + backend_type=nllb_backend, device = device, - model_size = model_size + nllb_size = nllb_size ) for src_lang in src_langs: if src_lang != 'auto': @@ -157,7 +157,7 @@ if __name__ == '__main__': test = test_string.split(' ') step = len(test) // 3 - shared_model = load_model([input_lang], backend='ctranslate2') + shared_model = load_model([input_lang], nllb_backend='ctranslate2') online_translation = OnlineTranslation(shared_model, input_languages=[input_lang], output_languages=[output_lang]) beg_inference = time.time() diff --git a/whisperlivekit/whisper_streaming_custom/backends.py b/whisperlivekit/whisper_streaming_custom/backends.py index 8f7d643..39c04ec 100644 --- a/whisperlivekit/whisper_streaming_custom/backends.py +++ b/whisperlivekit/whisper_streaming_custom/backends.py @@ -11,14 +11,14 @@ class ASRBase: sep = " " # join transcribe words with this character (" " for whisper_timestamped, # "" for faster-whisper because it emits the spaces when needed) - def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr): + def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr): self.logfile = logfile self.transcribe_kargs = {} if lan == "auto": self.original_language = None else: self.original_language = lan - self.model = self.load_model(modelsize, cache_dir, model_dir) + self.model = self.load_model(model_size, cache_dir, model_dir) def with_offset(self, offset: float) -> ASRToken: # This method is kept for compatibility (typically you will use ASRToken.with_offset) @@ -27,7 +27,7 @@ class ASRBase: def __repr__(self): return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})" - def load_model(self, modelsize, cache_dir, model_dir): + def load_model(self, model_size, cache_dir, model_dir): raise NotImplementedError("must be implemented in the child class") def transcribe(self, audio, init_prompt=""): @@ -41,7 +41,7 @@ class WhisperTimestampedASR(ASRBase): """Uses whisper_timestamped as the backend.""" sep = " " - def load_model(self, modelsize=None, cache_dir=None, model_dir=None): + def load_model(self, model_size=None, cache_dir=None, model_dir=None): import whisper import whisper_timestamped from whisper_timestamped import transcribe_timestamped @@ -49,7 +49,7 @@ class WhisperTimestampedASR(ASRBase): self.transcribe_timestamped = transcribe_timestamped if model_dir is not None: logger.debug("ignoring model_dir, not implemented") - return whisper.load_model(modelsize, download_root=cache_dir) + return whisper.load_model(model_size, download_root=cache_dir) def transcribe(self, audio, init_prompt=""): result = self.transcribe_timestamped( @@ -88,17 +88,17 @@ class FasterWhisperASR(ASRBase): """Uses faster-whisper as the backend.""" sep = "" - def load_model(self, modelsize=None, cache_dir=None, model_dir=None): + def load_model(self, model_size=None, cache_dir=None, model_dir=None): from faster_whisper import WhisperModel if model_dir is not None: logger.debug(f"Loading whisper model from model_dir {model_dir}. " - f"modelsize and cache_dir parameters are not used.") + f"model_size and cache_dir parameters are not used.") model_size_or_path = model_dir - elif modelsize is not None: - model_size_or_path = modelsize + elif model_size is not None: + model_size_or_path = model_size else: - raise ValueError("Either modelsize or model_dir must be set") + raise ValueError("Either model_size or model_dir must be set") device = "auto" # Allow CTranslate2 to decide available device compute_type = "auto" # Allow CTranslate2 to decide faster compute type @@ -149,18 +149,18 @@ class MLXWhisper(ASRBase): """ sep = "" - def load_model(self, modelsize=None, cache_dir=None, model_dir=None): + def load_model(self, model_size=None, cache_dir=None, model_dir=None): from mlx_whisper.transcribe import ModelHolder, transcribe import mlx.core as mx if model_dir is not None: - logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.") + logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.") model_size_or_path = model_dir - elif modelsize is not None: - model_size_or_path = self.translate_model_name(modelsize) - logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.") + elif model_size is not None: + model_size_or_path = self.translate_model_name(model_size) + logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.") else: - raise ValueError("Either modelsize or model_dir must be set") + raise ValueError("Either model_size or model_dir must be set") self.model_size_or_path = model_size_or_path dtype = mx.float16 diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py index 9e66842..deca515 100644 --- a/whisperlivekit/whisper_streaming_custom/online_asr.py +++ b/whisperlivekit/whisper_streaming_custom/online_asr.py @@ -106,9 +106,6 @@ class OnlineASRProcessor: def __init__( self, asr, - tokenize_method: Optional[callable] = None, - buffer_trimming: Tuple[str, float] = ("segment", 15), - confidence_validation = False, logfile=sys.stderr, ): """ @@ -119,13 +116,14 @@ class OnlineASRProcessor: buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment". """ self.asr = asr - self.tokenize = tokenize_method + self.tokenize = asr.tokenizer self.logfile = logfile - self.confidence_validation = confidence_validation + self.confidence_validation = asr.confidence_validation self.global_time_offset = 0.0 self.init() - self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming + self.buffer_trimming_way = asr.buffer_trimming + self.buffer_trimming_sec = asr.buffer_trimming_sec if self.buffer_trimming_way not in ["sentence", "segment"]: raise ValueError("buffer_trimming must be either 'sentence' or 'segment'") diff --git a/whisperlivekit/whisper_streaming_custom/whisper_online.py b/whisperlivekit/whisper_streaming_custom/whisper_online.py index db59027..6fae3ab 100644 --- a/whisperlivekit/whisper_streaming_custom/whisper_online.py +++ b/whisperlivekit/whisper_streaming_custom/whisper_online.py @@ -6,6 +6,7 @@ from functools import lru_cache import time import logging from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR +from whisperlivekit.warmup import warmup_asr logger = logging.getLogger(__name__) @@ -63,11 +64,23 @@ def create_tokenizer(lan): return WtPtok() -def backend_factory(args): - backend = args.backend +def backend_factory( + backend, + lan, + model_size, + model_cache_dir, + model_dir, + task, + buffer_trimming, + buffer_trimming_sec, + confidence_validation, + warmup_file=None, + min_chunk_size=None, + ): + backend = backend if backend == "openai-api": logger.debug("Using OpenAI API.") - asr = OpenaiApiASR(lan=args.lan) + asr = OpenaiApiASR(lan=lan) else: if backend == "faster-whisper": asr_cls = FasterWhisperASR @@ -77,34 +90,33 @@ def backend_factory(args): asr_cls = WhisperTimestampedASR # Only for FasterWhisperASR and WhisperTimestampedASR - size = args.model + t = time.time() - logger.info(f"Loading Whisper {size} model for language {args.lan}...") + logger.info(f"Loading Whisper {model_size} model for language {lan}...") asr = asr_cls( - modelsize=size, - lan=args.lan, - cache_dir=getattr(args, 'model_cache_dir', None), - model_dir=getattr(args, 'model_dir', None), + model_size=model_size, + lan=lan, + cache_dir=model_cache_dir, + model_dir=model_dir, ) e = time.time() logger.info(f"done. It took {round(e-t,2)} seconds.") - # Apply common configurations - if getattr(args, "vad", False): # Checks if VAD argument is present and True - logger.info("Setting VAD filter") - asr.use_vad() - - language = args.lan - if args.task == "translate": - if backend != "simulstreaming": - asr.set_translate_task() + if task == "translate": tgt_language = "en" # Whisper translates into English else: - tgt_language = language # Whisper transcribes in this language + tgt_language = lan # Whisper transcribes in this language # Create the tokenizer - if args.buffer_trimming == "sentence": + if buffer_trimming == "sentence": tokenizer = create_tokenizer(tgt_language) else: tokenizer = None - return asr, tokenizer \ No newline at end of file + + warmup_asr(asr, warmup_file) + + asr.confidence_validation = confidence_validation + asr.tokenizer = tokenizer + asr.buffer_trimming = buffer_trimming + asr.buffer_trimming_sec = buffer_trimming_sec + return asr \ No newline at end of file