fix: 卸载音色省显存

顺便将所有print换成了统一的logger
This commit is contained in:
源文雨
2023-09-01 15:18:08 +08:00
parent 8d5a77dbe9
commit 04a33b9709
23 changed files with 189 additions and 106 deletions

View File

@@ -1,5 +1,7 @@
import os
import sys
import logging
logger = logging.getLogger(__name__)
now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))
@@ -82,7 +84,7 @@ def main():
n_gpus = 1
if n_gpus < 1:
# patch to unblock people without gpus. there is probably a better way.
print("NO GPU DETECTED: falling back to CPU - this may take a while")
logger.warn("NO GPU DETECTED: falling back to CPU - this may take a while")
n_gpus = 1
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
@@ -209,7 +211,7 @@ def run(rank, n_gpus, hps):
if hps.pretrainG != "":
if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainG))
print(
logger.info(
net_g.module.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
@@ -217,7 +219,7 @@ def run(rank, n_gpus, hps):
if hps.pretrainD != "":
if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainD))
print(
logger.info(
net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)

View File

@@ -1,4 +1,6 @@
import os
import logging
logger = logging.getLogger(__name__)
import librosa
import numpy as np
@@ -88,7 +90,7 @@ class Predictor:
def __init__(self, args):
import onnxruntime as ort
print(ort.get_available_providers())
logger.info(ort.get_available_providers())
self.args = args
self.model_ = get_models(
device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
@@ -101,7 +103,7 @@ class Predictor:
"CPUExecutionProvider",
],
)
print("ONNX load done")
logger.info("ONNX load done")
def demix(self, mix):
samples = mix.shape[-1]

View File

@@ -1,5 +1,7 @@
import os
import traceback
import logging
logger = logging.getLogger(__name__)
import ffmpeg
import torch
@@ -92,5 +94,5 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
traceback.print_exc()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("Executed torch.cuda.empty_cache()")
logger.info("Executed torch.cuda.empty_cache()")
yield "\n".join(infos)

View File

@@ -1,4 +1,6 @@
import os
import logging
logger = logging.getLogger(__name__)
import librosa
import numpy as np
@@ -116,7 +118,7 @@ class AudioPre:
)
else:
wav_instrument = spec_utils.cmb_spectrogram_to_wave(y_spec_m, self.mp)
print("%s instruments done" % name)
logger.info("%s instruments done" % name)
if format in ["wav", "flac"]:
sf.write(
os.path.join(
@@ -150,7 +152,7 @@ class AudioPre:
)
else:
wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
print("%s vocals done" % name)
logger.info("%s vocals done" % name)
if format in ["wav", "flac"]:
sf.write(
os.path.join(
@@ -283,7 +285,7 @@ class AudioPreDeEcho:
)
else:
wav_instrument = spec_utils.cmb_spectrogram_to_wave(y_spec_m, self.mp)
print("%s instruments done" % name)
logger.info("%s instruments done" % name)
if format in ["wav", "flac"]:
sf.write(
os.path.join(
@@ -317,7 +319,7 @@ class AudioPreDeEcho:
)
else:
wav_vocals = spec_utils.cmb_spectrogram_to_wave(v_spec_m, self.mp)
print("%s vocals done" % name)
logger.info("%s vocals done" % name)
if format in ["wav", "flac"]:
sf.write(
os.path.join(

View File

@@ -1,4 +1,6 @@
import traceback
import logging
logger = logging.getLogger(__name__)
import numpy as np
import soundfile as sf
@@ -30,14 +32,7 @@ class VC:
self.config = config
def get_vc(self, sid, *to_return_protect):
person = f'{os.getenv("weight_root")}/{sid}'
print(f"Loading: {person}")
self.cpt = torch.load(person, map_location="cpu")
self.tgt_sr = self.cpt["config"][-1]
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
self.if_f0 = self.cpt.get("f0", 1)
self.version = self.cpt.get("version", "v1")
logger.info("Get sid: " + sid)
to_return_protect0 = {
"visible": self.if_f0 != 0,
@@ -54,6 +49,57 @@ class VC:
"__type__": "update",
}
if not sid:
if self.hubert_model is not None: # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
logger.info("Clean model cache")
del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr # ,cpt
self.hubert_model = self.net_g = self.n_spk = self.vc = self.hubert_model = self.tgt_sr = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
###楼下不这么折腾清理不干净
self.if_f0 = self.cpt.get("f0", 1)
self.version = self.cpt.get("version", "v1")
if self.version == "v1":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs256NSFsid(
*self.cpt["config"], is_half=self.config.is_half
)
else:
self.net_g = SynthesizerTrnMs256NSFsid_nono(*self.cpt["config"])
elif self.version == "v2":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs768NSFsid(
*self.cpt["config"], is_half=self.config.is_half
)
else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*self.cpt["config"])
del self.net_g, self.cpt
if torch.cuda.is_available():
torch.cuda.empty_cache()
return (
{"visible": False, "__type__": "update"},
{
"visible": True,
"value": to_return_protect0,
"__type__": "update",
},
{
"visible": True,
"value": to_return_protect1,
"__type__": "update",
},
"",
"",
)
person = f'{os.getenv("weight_root")}/{sid}'
logger.info(f"Loading: {person}")
self.cpt = torch.load(person, map_location="cpu")
self.tgt_sr = self.cpt["config"][-1]
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
self.if_f0 = self.cpt.get("f0", 1)
self.version = self.cpt.get("version", "v1")
synthesizer_class = {
("v1", 1): SynthesizerTrnMs256NSFsid,
("v1", 0): SynthesizerTrnMs256NSFsid_nono,
@@ -77,7 +123,7 @@ class VC:
self.pipeline = Pipeline(self.tgt_sr, self.config)
n_spk = self.cpt["config"][-3]
index = {"value": get_index_path_from_model(sid), "__type__": "update"}
print("Select index:", index["value"])
logger.info("Select index: " + index["value"])
return (
(
@@ -165,7 +211,7 @@ class VC:
)
except:
info = traceback.format_exc()
print(info)
logger.warn(info)
return info, (None, None)
def vc_multi(

View File

@@ -1,6 +1,9 @@
import os
import sys
import traceback
import logging
logger = logging.getLogger(__name__)
from functools import lru_cache
from time import time as ttime
@@ -139,7 +142,7 @@ class Pipeline(object):
if not hasattr(self, "model_rmvpe"):
from infer.lib.rmvpe import RMVPE
print(
logger.info(
"Loading rmvpe model,%s" % "%s/rmvpe.pt" % os.environ["rmvpe_root"]
)
self.model_rmvpe = RMVPE(
@@ -152,7 +155,7 @@ class Pipeline(object):
if "privateuseone" in str(self.device): # clean ortruntime memory
del self.model_rmvpe.model
del self.model_rmvpe
print("Cleaning ortruntime memory")
logger.info("Cleaning ortruntime memory")
f0 *= pow(2, f0_up_key / 12)
# with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))