fix: 卸载音色省显存

顺便将所有print换成了统一的logger
This commit is contained in:
源文雨
2023-09-01 15:18:08 +08:00
parent 8d5a77dbe9
commit 04a33b9709
23 changed files with 189 additions and 106 deletions

View File

@@ -1,7 +1,6 @@
import math
import os
import pdb
from time import time as ttime
import logging
logger = logging.getLogger(__name__)
import numpy as np
import torch
@@ -616,7 +615,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
logger.debug("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
def remove_weight_norm(self):
self.dec.remove_weight_norm()
@@ -732,7 +731,7 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
logger.debug("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
def remove_weight_norm(self):
self.dec.remove_weight_norm()
@@ -845,7 +844,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
logger.debug("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
def remove_weight_norm(self):
self.dec.remove_weight_norm()
@@ -951,7 +950,7 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
logger.debug("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
def remove_weight_norm(self):
self.dec.remove_weight_norm()

View File

@@ -1,7 +1,6 @@
import math
import os
import pdb
from time import time as ttime
import logging
logger = logging.getLogger(__name__)
import numpy as np
import torch
@@ -620,7 +619,7 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
self.speaker_map = None
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
logger.debug("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
def remove_weight_norm(self):
self.dec.remove_weight_norm()

View File

@@ -3,10 +3,13 @@ import numpy as np
import onnxruntime
import soundfile
import logging
logger = logging.getLogger(__name__)
class ContentVec:
def __init__(self, vec_path="pretrained/vec-768-layer-12.onnx", device=None):
print("Load model(s) from {}".format(vec_path))
logger.info("Load model(s) from {}".format(vec_path))
if device == "cpu" or device is None:
providers = ["CPUExecutionProvider"]
elif device == "cuda":

View File

@@ -7,6 +7,10 @@ import torch.nn.functional as F
from librosa.util import normalize, pad_center, tiny
from scipy.signal import get_window
import logging
logger = logging.getLogger(__name__)
###stft codes from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
def window_sumsquare(
@@ -691,4 +695,4 @@ if __name__ == "__main__":
# f0 = rmvpe.infer_from_audio(audio, thred=thred)
# f0 = rmvpe.infer_from_audio(audio, thred=thred)
t1 = ttime()
print(f0.shape, t1 - t0)
logger.info(f0.shape, t1 - t0)

View File

@@ -1,5 +1,7 @@
import os
import traceback
import logging
logger = logging.getLogger(__name__)
import numpy as np
import torch
@@ -110,7 +112,7 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
try:
spec = torch.load(spec_filename)
except:
print(spec_filename, traceback.format_exc())
logger.warn(spec_filename, traceback.format_exc())
spec = spectrogram_torch(
audio_norm,
self.filter_length,
@@ -302,7 +304,7 @@ class TextAudioLoader(torch.utils.data.Dataset):
try:
spec = torch.load(spec_filename)
except:
print(spec_filename, traceback.format_exc())
logger.warn(spec_filename, traceback.format_exc())
spec = spectrogram_torch(
audio_norm,
self.filter_length,

View File

@@ -1,6 +1,8 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
import logging
logger = logging.getLogger(__name__)
MAX_WAV_VALUE = 32768.0
@@ -51,9 +53,9 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
"""
# Validation
if torch.min(y) < -1.07:
print("spectrogram_torch min value is ", torch.min(y))
logger.debug("min value is ", torch.min(y))
if torch.max(y) > 1.07:
print("spectrogram_torch max value is ", torch.max(y))
logger.debug("max value is ", torch.max(y))
# Window - Cache if needed
global hann_window

View File

@@ -33,7 +33,7 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
logger.warn(
"shape-%s-mismatch. need: %s, get: %s"
% (k, state_dict[k].shape, saved_state_dict[k].shape)
) #
@@ -109,7 +109,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
logger.warn(
"shape-%s-mismatch|need-%s|get-%s"
% (k, state_dict[k].shape, saved_state_dict[k].shape)
) #
@@ -207,7 +207,7 @@ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
f_list = glob.glob(os.path.join(dir_path, regex))
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
x = f_list[-1]
print(x)
logger.debug(x)
return x