From 0d874fb5151a7b05e5fbabb1bb47e32458a179cc Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Fri, 7 Feb 2025 10:16:03 +0100
Subject: [PATCH] cuda or cpu auto detection

---
 src/whisper_streaming/backends.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/whisper_streaming/backends.py b/src/whisper_streaming/backends.py
index 99ba762..20522ed 100644
--- a/src/whisper_streaming/backends.py
+++ b/src/whisper_streaming/backends.py
@@ -4,7 +4,7 @@ import logging
 import io
 import soundfile as sf
 import math
-
+import torch
 
 logger = logging.getLogger(__name__)
 
@@ -102,11 +102,13 @@ class FasterWhisperASR(ASRBase):
         else:
             raise ValueError("modelsize or model_dir parameter must be set")
 
-        # this worked fast and reliably on NVIDIA L40
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        compute_type = "float16" if device == "cuda" else "float32"
+
         model = WhisperModel(
             model_size_or_path,
-            device="cuda",
-            compute_type="float16",
+            device=device,
+            compute_type=compute_type,
             download_root=cache_dir,
         )