mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-01-20 02:51:09 +00:00
Add files via upload
This commit is contained in:
@@ -31,14 +31,21 @@ from data_utils import (
|
||||
TextAudioCollate,
|
||||
DistributedBucketSampler,
|
||||
)
|
||||
from infer_pack.models import (
|
||||
SynthesizerTrnMs256NSFsid,
|
||||
SynthesizerTrnMs256NSFsid_nono,
|
||||
MultiPeriodDiscriminator,
|
||||
)
|
||||
if(hps.version=="v1"):
|
||||
from infer_pack.models import (
|
||||
SynthesizerTrnMs256NSFsid as RVC_Model_f0,
|
||||
SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
|
||||
MultiPeriodDiscriminator,
|
||||
)
|
||||
else:
|
||||
from infer_pack.models import (
|
||||
SynthesizerTrnMs768NSFsid as RVC_Model_f0,
|
||||
SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
|
||||
MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
|
||||
)
|
||||
from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
|
||||
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||
|
||||
from process_ckpt import savee
|
||||
|
||||
global_step = 0
|
||||
|
||||
@@ -63,7 +70,7 @@ def run(rank, n_gpus, hps):
|
||||
if rank == 0:
|
||||
logger = utils.get_logger(hps.model_dir)
|
||||
logger.info(hps)
|
||||
utils.check_git_hash(hps.model_dir)
|
||||
# utils.check_git_hash(hps.model_dir)
|
||||
writer = SummaryWriter(log_dir=hps.model_dir)
|
||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
|
||||
|
||||
@@ -104,7 +111,7 @@ def run(rank, n_gpus, hps):
|
||||
prefetch_factor=8,
|
||||
)
|
||||
if hps.if_f0 == 1:
|
||||
net_g = SynthesizerTrnMs256NSFsid(
|
||||
net_g = RVC_Model_f0(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
**hps.model,
|
||||
@@ -112,7 +119,7 @@ def run(rank, n_gpus, hps):
|
||||
sr=hps.sample_rate,
|
||||
)
|
||||
else:
|
||||
net_g = SynthesizerTrnMs256NSFsid_nono(
|
||||
net_g = RVC_Model_nof0(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
**hps.model,
|
||||
@@ -343,7 +350,7 @@ def train_and_evaluate(
|
||||
spec = spec.cuda(rank, non_blocking=True)
|
||||
spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
|
||||
wave = wave.cuda(rank, non_blocking=True)
|
||||
wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
|
||||
# wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
|
||||
|
||||
# Calculate
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
@@ -428,10 +435,10 @@ def train_and_evaluate(
|
||||
)
|
||||
)
|
||||
# Amor For Tensorboard display
|
||||
if loss_mel > 50:
|
||||
loss_mel = 50
|
||||
if loss_kl > 5:
|
||||
loss_kl = 5
|
||||
if loss_mel > 75:
|
||||
loss_mel = 75
|
||||
if loss_kl > 9:
|
||||
loss_kl = 9
|
||||
|
||||
logger.info([global_step, lr])
|
||||
logger.info(
|
||||
@@ -512,12 +519,20 @@ def train_and_evaluate(
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
|
||||
)
|
||||
if(rank==0 and hps.save_every_weights=="1"):
|
||||
if hasattr(net_g, "module"):
|
||||
ckpt = net_g.module.state_dict()
|
||||
else:
|
||||
ckpt = net_g.state_dict()
|
||||
logger.info(
|
||||
"saving ckpt %s_e%s:%s"
|
||||
% (hps.name,epoch,savee(ckpt, hps.sample_rate, hps.if_f0, hps.name+"_e%s"%epoch, epoch,hps.version))
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
logger.info("====> Epoch: {}".format(epoch))
|
||||
if epoch >= hps.total_epoch and rank == 0:
|
||||
logger.info("Training is done. The program is closed.")
|
||||
from process_ckpt import savee # def savee(ckpt,sr,if_f0,name,epoch):
|
||||
|
||||
if hasattr(net_g, "module"):
|
||||
ckpt = net_g.module.state_dict()
|
||||
@@ -525,7 +540,7 @@ def train_and_evaluate(
|
||||
ckpt = net_g.state_dict()
|
||||
logger.info(
|
||||
"saving final ckpt:%s"
|
||||
% (savee(ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch))
|
||||
% (savee(ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch,hps.version))
|
||||
)
|
||||
sleep(1)
|
||||
os._exit(2333333)
|
||||
|
||||
Reference in New Issue
Block a user