fix: 卸载音色省显存

顺便将所有print换成了统一的logger
This commit is contained in:
源文雨
2023-09-01 15:18:08 +08:00
parent 8d5a77dbe9
commit 04a33b9709
23 changed files with 189 additions and 106 deletions

View File

@@ -1,5 +1,7 @@
import os
import traceback
import logging
logger = logging.getLogger(__name__)
import numpy as np
import torch
@@ -110,7 +112,7 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
try:
spec = torch.load(spec_filename)
except:
print(spec_filename, traceback.format_exc())
logger.warn(spec_filename, traceback.format_exc())
spec = spectrogram_torch(
audio_norm,
self.filter_length,
@@ -302,7 +304,7 @@ class TextAudioLoader(torch.utils.data.Dataset):
try:
spec = torch.load(spec_filename)
except:
print(spec_filename, traceback.format_exc())
logger.warn(spec_filename, traceback.format_exc())
spec = spectrogram_torch(
audio_norm,
self.filter_length,

View File

@@ -1,6 +1,8 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
import logging
logger = logging.getLogger(__name__)
MAX_WAV_VALUE = 32768.0
@@ -51,9 +53,9 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
"""
# Validation
if torch.min(y) < -1.07:
print("spectrogram_torch min value is ", torch.min(y))
logger.debug("min value is ", torch.min(y))
if torch.max(y) > 1.07:
print("spectrogram_torch max value is ", torch.max(y))
logger.debug("max value is ", torch.max(y))
# Window - Cache if needed
global hann_window

View File

@@ -33,7 +33,7 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
logger.warn(
"shape-%s-mismatch. need: %s, get: %s"
% (k, state_dict[k].shape, saved_state_dict[k].shape)
) #
@@ -109,7 +109,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
logger.warn(
"shape-%s-mismatch|need-%s|get-%s"
% (k, state_dict[k].shape, saved_state_dict[k].shape)
) #
@@ -207,7 +207,7 @@ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
f_list = glob.glob(os.path.join(dir_path, regex))
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
x = f_list[-1]
print(x)
logger.debug(x)
return x