optimize: 精简未用到的配置项并在特征提取初步引入mps (#32)

This commit is contained in:
源文雨
2023-04-11 18:14:55 +08:00
committed by GitHub
parent 0656591373
commit ecc744d748
10 changed files with 82 additions and 57 deletions

View File

@@ -1,13 +1,12 @@
import os,sys,traceback
if len(sys.argv) == 4:
n_part=int(sys.argv[1])
i_part=int(sys.argv[2])
exp_dir=sys.argv[3]
else:
n_part=int(sys.argv[1])
i_part=int(sys.argv[2])
i_gpu=sys.argv[3]
device=sys.argv[1]
n_part=int(sys.argv[2])
i_part=int(sys.argv[3])
if len(sys.argv) == 5:
exp_dir=sys.argv[4]
else:
i_gpu=sys.argv[4]
exp_dir=sys.argv[5]
os.environ["CUDA_VISIBLE_DEVICES"]=str(i_gpu)
import torch
@@ -15,7 +14,6 @@ import torch.nn.functional as F
import soundfile as sf
import numpy as np
from fairseq import checkpoint_utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
f = open("%s/extract_f0_feature.log"%exp_dir, "a+")
def printt(strr):
@@ -50,8 +48,8 @@ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
)
model = models[0]
model = model.to(device)
if torch.cuda.is_available():
model = model.half()
printt("move model to "+device)
if device != "cpu": model = model.half()
model.eval()
todo=sorted(list(os.listdir(wavPath)))[i_part::n_part]
@@ -70,7 +68,7 @@ else:
feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
inputs = {
"source": feats.half().to(device) if torch.cuda.is_available() else feats.to(device),
"source": feats.half().to(device) if device != "cpu" else feats.to(device),
"padding_mask": padding_mask.to(device),
"output_layer": 9, # layer 9
}