fix: index_root searching

close #1147
This commit is contained in:
源文雨
2023-09-01 14:11:55 +08:00
parent d634c2727e
commit 8ffdcb0128
21 changed files with 59 additions and 76 deletions

View File

@@ -85,7 +85,7 @@ class FeatureInput(object):
if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE
print("loading rmvpe model")
print("Loading rmvpe model")
self.model_rmvpe = RMVPE(
"assets/rmvpe/rmvpe.pt", is_half=False, device="cpu"
)

View File

@@ -48,7 +48,7 @@ class FeatureInput(object):
if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE
print("loading rmvpe model")
print("Loading rmvpe model")
self.model_rmvpe = RMVPE(
"assets/rmvpe/rmvpe.pt", is_half=is_half, device="cuda"
)

View File

@@ -46,7 +46,7 @@ class FeatureInput(object):
if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE
print("loading rmvpe model")
print("Loading rmvpe model")
self.model_rmvpe = RMVPE(
"assets/rmvpe/rmvpe.pt", is_half=False, device=device
)

View File

@@ -101,7 +101,7 @@ class Predictor:
"CPUExecutionProvider",
],
)
print("onnx load done")
print("ONNX load done")
def demix(self, mix):
samples = mix.shape[-1]

View File

@@ -90,7 +90,7 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
del pre_fun
except:
traceback.print_exc()
print("clean_empty_cache")
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("Executed torch.cuda.empty_cache()")
yield "\n".join(infos)

View File

@@ -31,7 +31,7 @@ class VC:
def get_vc(self, sid, *to_return_protect):
person = f'{os.getenv("weight_root")}/{sid}'
print(f"loading {person}")
print(f"Loading: {person}")
self.cpt = torch.load(person, map_location="cpu")
self.tgt_sr = self.cpt["config"][-1]
@@ -77,6 +77,7 @@ class VC:
self.pipeline = Pipeline(self.tgt_sr, self.config)
n_spk = self.cpt["config"][-3]
index = {"value": get_index_path_from_model(sid), "__type__": "update"}
print("Select index:", index["value"])
return (
(

View File

@@ -140,7 +140,7 @@ class Pipeline(object):
from infer.lib.rmvpe import RMVPE
print(
"loading rmvpe model,%s" % "%s/rmvpe.pt" % os.environ["rmvpe_root"]
"Loading rmvpe model,%s" % "%s/rmvpe.pt" % os.environ["rmvpe_root"]
)
self.model_rmvpe = RMVPE(
"%s/rmvpe.pt" % os.environ["rmvpe_root"],
@@ -152,7 +152,7 @@ class Pipeline(object):
if "privateuseone" in str(self.device): # clean ortruntime memory
del self.model_rmvpe.model
del self.model_rmvpe
print("cleaning ortruntime memory")
print("Cleaning ortruntime memory")
f0 *= pow(2, f0_up_key / 12)
# with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
@@ -262,17 +262,12 @@ class Pipeline(object):
feats = feats.to(feats0.dtype)
p_len = torch.tensor([p_len], device=self.device).long()
with torch.no_grad():
if pitch is not None and pitchf is not None:
audio1 = (
(net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
.data.cpu()
.float()
.numpy()
)
else:
audio1 = (
(net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
)
hasp = pitch is not None and pitchf is not None
arg = (feats, p_len, pitch, pitchf, sid) if hasp else (feats, p_len, sid)
audio1 = (
(net_g.infer(*arg)[0][0, 0]).data.cpu().float().numpy()
)
del hasp, arg
del feats, p_len, padding_mask
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@@ -2,33 +2,20 @@ import os
from fairseq import checkpoint_utils
### don't modify the code before you test it
# def get_index_path_from_model(sid):
# return next(
# (
# f
# for f in [
# os.path.join(root, name)
# for root, dirs, files in os.walk(os.getenv("index_root"), topdown=False)
# for name in files
# if name.endswith(".index") and "trained" not in name
# ]
# if sid.split(".")[0] in f
# ),
# "",
# )
def get_index_path_from_model(sid):
sel_index_path = ""
name = os.path.join("logs", sid.split(".")[0], "")
# print(name)
for f in index_paths:
if name in f:
# print("selected index path:", f)
sel_index_path = f
break
return sel_index_path
return next(
(
f
for f in [
os.path.join(root, name)
for root, _, files in os.walk(os.getenv("index_root"), topdown=False)
for name in files
if name.endswith(".index") and "trained" not in name
]
if sid.split(".")[0] in f
),
"",
)
def load_hubert(config):