Format code (#989)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2023-08-13 11:52:51 +08:00
committed by GitHub
parent 7293002f53
commit 76b67842ba
10 changed files with 218 additions and 117 deletions

View File

@@ -3,7 +3,7 @@ import os, sys, traceback
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
device=sys.argv[1]
device = sys.argv[1]
n_part = int(sys.argv[2])
i_part = int(sys.argv[3])
if len(sys.argv) == 6:
@@ -20,7 +20,7 @@ import soundfile as sf
import numpy as np
import fairseq
if("privateuseone"not in device):
if "privateuseone" not in device:
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
@@ -28,12 +28,15 @@ if("privateuseone"not in device):
device = "mps"
else:
import torch_directml
device = torch_directml.device(torch_directml.default_device())
def forward_dml(ctx, x, scale):
ctx.scale = scale
res = x.clone().detach()
return res
fairseq.modules.grad_multiply.GradMultiply.forward=forward_dml
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
f = open("%s/extract_f0_feature.log" % exp_dir, "a+")