Format code (#989)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2023-08-13 11:52:51 +08:00
committed by GitHub
parent 7293002f53
commit 76b67842ba
10 changed files with 218 additions and 117 deletions

View File

@@ -83,12 +83,13 @@ def get_models(device, dim_f, dim_t, n_fft):
warnings.filterwarnings("ignore")
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from config import Config
cpu = torch.device("cpu")
device=Config().device
device = Config().device
# if torch.cuda.is_available():
# device = torch.device("cuda:0")
# elif torch.backends.mps.is_available():
@@ -104,10 +105,15 @@ class Predictor:
device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
)
import onnxruntime as ort
print(ort.get_available_providers())
self.model = ort.InferenceSession(
os.path.join(args.onnx, self.model_.target_name + ".onnx"),
providers=["CUDAExecutionProvider", "DmlExecutionProvider","CPUExecutionProvider"],
providers=[
"CUDAExecutionProvider",
"DmlExecutionProvider",
"CPUExecutionProvider",
],
)
print("onnx load done")