diff --git a/src/whisper_streaming/backends.py b/src/whisper_streaming/backends.py
index 99ba762..20522ed 100644
--- a/src/whisper_streaming/backends.py
+++ b/src/whisper_streaming/backends.py
@@ -4,7 +4,7 @@
 import logging
 import io
 import soundfile as sf
 import math
-
+import torch
 
 logger = logging.getLogger(__name__)
@@ -102,11 +102,13 @@ class FasterWhisperASR(ASRBase):
         else:
             raise ValueError("modelsize or model_dir parameter must be set")
 
-        # this worked fast and reliably on NVIDIA L40
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        compute_type = "float16" if device == "cuda" else "float32"
+
         model = WhisperModel(
             model_size_or_path,
-            device="cuda",
-            compute_type="float16",
+            device=device,
+            compute_type=compute_type,
             download_root=cache_dir,
         )