mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-01-20 02:51:09 +00:00
fix
This commit is contained in:
@@ -34,9 +34,6 @@ global_step = 0
|
||||
|
||||
|
||||
def main():
|
||||
"""Assume Single Node Multi GPUs Training Only"""
|
||||
assert torch.cuda.is_available(), "CPU training is not allowed."
|
||||
|
||||
# n_gpus = torch.cuda.device_count()
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "5555"
|
||||
@@ -65,7 +62,7 @@ def run(rank, n_gpus, hps):
|
||||
backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
|
||||
)
|
||||
torch.manual_seed(hps.train.seed)
|
||||
torch.cuda.set_device(rank)
|
||||
if torch.cuda.is_available(): torch.cuda.set_device(rank)
|
||||
|
||||
if (hps.if_f0 == 1):train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
|
||||
else:train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
|
||||
@@ -92,9 +89,13 @@ def run(rank, n_gpus, hps):
|
||||
persistent_workers=True,
|
||||
prefetch_factor=8,
|
||||
)
|
||||
if(hps.if_f0==1):net_g = SynthesizerTrnMs256NSFsid(hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,**hps.model,is_half=hps.train.fp16_run,sr=hps.sample_rate).cuda(rank)
|
||||
else:net_g = SynthesizerTrnMs256NSFsid_nono(hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,**hps.model,is_half=hps.train.fp16_run).cuda(rank)
|
||||
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
|
||||
if(hps.if_f0==1):
|
||||
net_g = SynthesizerTrnMs256NSFsid(hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,**hps.model,is_half=hps.train.fp16_run,sr=hps.sample_rate)
|
||||
else:
|
||||
net_g = SynthesizerTrnMs256NSFsid_nono(hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,**hps.model,is_half=hps.train.fp16_run)
|
||||
if torch.cuda.is_available(): net_g = net_g.cuda(rank)
|
||||
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
|
||||
if torch.cuda.is_available(): net_d = net_d.cuda(rank)
|
||||
optim_g = torch.optim.AdamW(
|
||||
net_g.parameters(),
|
||||
hps.train.learning_rate,
|
||||
@@ -109,8 +110,12 @@ def run(rank, n_gpus, hps):
|
||||
)
|
||||
# net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
|
||||
# net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
|
||||
net_g = DDP(net_g, device_ids=[rank])
|
||||
net_d = DDP(net_d, device_ids=[rank])
|
||||
if torch.cuda.is_available():
|
||||
net_g = DDP(net_g, device_ids=[rank])
|
||||
net_d = DDP(net_d, device_ids=[rank])
|
||||
else:
|
||||
net_g = DDP(net_g)
|
||||
net_d = DDP(net_d)
|
||||
|
||||
try:#如果能加载自动resume
|
||||
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d) # D多半加载没事
|
||||
@@ -190,11 +195,12 @@ def train_and_evaluate(
|
||||
for batch_idx, info in enumerate(train_loader):
|
||||
if (hps.if_f0 == 1):phone,phone_lengths,pitch,pitchf,spec,spec_lengths,wave,wave_lengths,sid=info
|
||||
else:phone,phone_lengths,spec,spec_lengths,wave,wave_lengths,sid=info
|
||||
phone, phone_lengths = phone.cuda(rank, non_blocking=True),phone_lengths.cuda(rank, non_blocking=True )
|
||||
if (hps.if_f0 == 1):pitch,pitchf = pitch.cuda(rank, non_blocking=True),pitchf.cuda(rank, non_blocking=True)
|
||||
sid = sid.cuda(rank, non_blocking=True)
|
||||
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
|
||||
wave, wave_lengths = wave.cuda(rank, non_blocking=True), wave_lengths.cuda(rank, non_blocking=True)
|
||||
if torch.cuda.is_available():
|
||||
phone, phone_lengths = phone.cuda(rank, non_blocking=True), phone_lengths.cuda(rank, non_blocking=True )
|
||||
if (hps.if_f0 == 1):pitch,pitchf = pitch.cuda(rank, non_blocking=True),pitchf.cuda(rank, non_blocking=True)
|
||||
sid = sid.cuda(rank, non_blocking=True)
|
||||
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
|
||||
wave, wave_lengths = wave.cuda(rank, non_blocking=True), wave_lengths.cuda(rank, non_blocking=True)
|
||||
if(hps.if_cache_data_in_gpu==True):
|
||||
if (hps.if_f0 == 1):cache.append((batch_idx, (phone,phone_lengths,pitch,pitchf,spec,spec_lengths,wave,wave_lengths ,sid)))
|
||||
else:cache.append((batch_idx, (phone,phone_lengths,spec,spec_lengths,wave,wave_lengths ,sid)))
|
||||
|
||||
Reference in New Issue
Block a user