cuda or cpu auto detection

This commit is contained in:
Quentin Fuxa
2025-02-07 10:16:03 +01:00
parent 4d1aa4421a
commit 0d874fb515

View File

@@ -4,7 +4,7 @@ import logging
import io
import soundfile as sf
import math
import torch
logger = logging.getLogger(__name__)
@@ -102,11 +102,13 @@ class FasterWhisperASR(ASRBase):
else:
raise ValueError("modelsize or model_dir parameter must be set")
# this worked fast and reliably on NVIDIA L40
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "float32"
model = WhisperModel(
model_size_or_path,
device="cuda",
compute_type="float16",
device=device,
compute_type=compute_type,
download_root=cache_dir,
)