mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-01-20 02:51:09 +00:00
Reformat and rewrite _get_name_params (#57)
* Reformat
* rewrite _get_name_params
* Add workflow for automatic formatting
* Revert "Add workflow for automatic formatting"
This reverts commit 9111c5dbc1.
* revert Retrieval_based_Voice_Conversion_WebUI.ipynb
---------
Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import os,traceback
|
||||
import os, traceback
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
@@ -6,6 +6,7 @@ import torch.utils.data
|
||||
from mel_processing import spectrogram_torch
|
||||
from utils import load_wav_to_torch, load_filepaths_and_text
|
||||
|
||||
|
||||
class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, text pairs
|
||||
@@ -15,14 +16,14 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, audiopaths_and_text, hparams):
|
||||
self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.filter_length = hparams.filter_length
|
||||
self.hop_length = hparams.hop_length
|
||||
self.win_length = hparams.win_length
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
||||
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.filter_length = hparams.filter_length
|
||||
self.hop_length = hparams.hop_length
|
||||
self.win_length = hparams.win_length
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
||||
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
||||
self._filter()
|
||||
|
||||
def _filter(self):
|
||||
@@ -34,12 +35,13 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
# spec_length = wav_length // hop_length
|
||||
audiopaths_and_text_new = []
|
||||
lengths = []
|
||||
for audiopath, text, pitch,pitchf,dv in self.audiopaths_and_text:
|
||||
for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
|
||||
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
||||
audiopaths_and_text_new.append([audiopath, text, pitch,pitchf,dv])
|
||||
audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
|
||||
lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
|
||||
self.audiopaths_and_text = audiopaths_and_text_new
|
||||
self.lengths = lengths
|
||||
|
||||
def get_sid(self, sid):
|
||||
sid = torch.LongTensor([int(sid)])
|
||||
return sid
|
||||
@@ -54,7 +56,7 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
|
||||
phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
|
||||
spec, wav = self.get_audio(file)
|
||||
dv=self.get_sid(dv)
|
||||
dv = self.get_sid(dv)
|
||||
|
||||
len_phone = phone.size()[0]
|
||||
len_spec = spec.size()[-1]
|
||||
@@ -71,9 +73,9 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
pitch = pitch[:len_min]
|
||||
pitchf = pitchf[:len_min]
|
||||
|
||||
return (spec, wav, phone, pitch,pitchf,dv)
|
||||
return (spec, wav, phone, pitch, pitchf, dv)
|
||||
|
||||
def get_labels(self, phone, pitch,pitchf):
|
||||
def get_labels(self, phone, pitch, pitchf):
|
||||
phone = np.load(phone)
|
||||
phone = np.repeat(phone, 2, axis=0)
|
||||
pitch = np.load(pitch)
|
||||
@@ -86,7 +88,7 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
phone = torch.FloatTensor(phone)
|
||||
pitch = torch.LongTensor(pitch)
|
||||
pitchf = torch.FloatTensor(pitchf)
|
||||
return phone, pitch,pitchf
|
||||
return phone, pitch, pitchf
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio, sampling_rate = load_wav_to_torch(filename)
|
||||
@@ -103,10 +105,15 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
try:
|
||||
spec = torch.load(spec_filename)
|
||||
except:
|
||||
print (spec_filename,traceback.format_exc())
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length,
|
||||
self.sampling_rate, self.hop_length, self.win_length,
|
||||
center=False)
|
||||
print(spec_filename, traceback.format_exc())
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
||||
else:
|
||||
@@ -127,6 +134,8 @@ class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_and_text)
|
||||
|
||||
|
||||
class TextAudioCollateMultiNSFsid:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
@@ -155,7 +164,9 @@ class TextAudioCollateMultiNSFsid:
|
||||
|
||||
max_phone_len = max([x[2].size(0) for x in batch])
|
||||
phone_lengths = torch.LongTensor(len(batch))
|
||||
phone_padded = torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])#(spec, wav, phone, pitch)
|
||||
phone_padded = torch.FloatTensor(
|
||||
len(batch), max_phone_len, batch[0][2].shape[1]
|
||||
) # (spec, wav, phone, pitch)
|
||||
pitch_padded = torch.LongTensor(len(batch), max_phone_len)
|
||||
pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
|
||||
phone_padded.zero_()
|
||||
@@ -187,7 +198,6 @@ class TextAudioCollateMultiNSFsid:
|
||||
# dv[i] = row[5]
|
||||
sid[i] = row[5]
|
||||
|
||||
|
||||
return (
|
||||
phone_padded,
|
||||
phone_lengths,
|
||||
@@ -198,9 +208,10 @@ class TextAudioCollateMultiNSFsid:
|
||||
wave_padded,
|
||||
wave_lengths,
|
||||
# dv
|
||||
sid
|
||||
sid,
|
||||
)
|
||||
|
||||
|
||||
class TextAudioLoader(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, text pairs
|
||||
@@ -210,14 +221,14 @@ class TextAudioLoader(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, audiopaths_and_text, hparams):
|
||||
self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.filter_length = hparams.filter_length
|
||||
self.hop_length = hparams.hop_length
|
||||
self.win_length = hparams.win_length
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
||||
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.filter_length = hparams.filter_length
|
||||
self.hop_length = hparams.hop_length
|
||||
self.win_length = hparams.win_length
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
||||
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
||||
self._filter()
|
||||
|
||||
def _filter(self):
|
||||
@@ -229,12 +240,13 @@ class TextAudioLoader(torch.utils.data.Dataset):
|
||||
# spec_length = wav_length // hop_length
|
||||
audiopaths_and_text_new = []
|
||||
lengths = []
|
||||
for audiopath, text,dv in self.audiopaths_and_text:
|
||||
for audiopath, text, dv in self.audiopaths_and_text:
|
||||
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
||||
audiopaths_and_text_new.append([audiopath, text,dv])
|
||||
audiopaths_and_text_new.append([audiopath, text, dv])
|
||||
lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
|
||||
self.audiopaths_and_text = audiopaths_and_text_new
|
||||
self.lengths = lengths
|
||||
|
||||
def get_sid(self, sid):
|
||||
sid = torch.LongTensor([int(sid)])
|
||||
return sid
|
||||
@@ -247,7 +259,7 @@ class TextAudioLoader(torch.utils.data.Dataset):
|
||||
|
||||
phone = self.get_labels(phone)
|
||||
spec, wav = self.get_audio(file)
|
||||
dv=self.get_sid(dv)
|
||||
dv = self.get_sid(dv)
|
||||
|
||||
len_phone = phone.size()[0]
|
||||
len_spec = spec.size()[-1]
|
||||
@@ -257,7 +269,7 @@ class TextAudioLoader(torch.utils.data.Dataset):
|
||||
spec = spec[:, :len_min]
|
||||
wav = wav[:, :len_wav]
|
||||
phone = phone[:len_min, :]
|
||||
return (spec, wav, phone,dv)
|
||||
return (spec, wav, phone, dv)
|
||||
|
||||
def get_labels(self, phone):
|
||||
phone = np.load(phone)
|
||||
@@ -282,10 +294,15 @@ class TextAudioLoader(torch.utils.data.Dataset):
|
||||
try:
|
||||
spec = torch.load(spec_filename)
|
||||
except:
|
||||
print (spec_filename,traceback.format_exc())
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length,
|
||||
self.sampling_rate, self.hop_length, self.win_length,
|
||||
center=False)
|
||||
print(spec_filename, traceback.format_exc())
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
||||
else:
|
||||
@@ -306,6 +323,8 @@ class TextAudioLoader(torch.utils.data.Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_and_text)
|
||||
|
||||
|
||||
class TextAudioCollate:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
@@ -334,7 +353,9 @@ class TextAudioCollate:
|
||||
|
||||
max_phone_len = max([x[2].size(0) for x in batch])
|
||||
phone_lengths = torch.LongTensor(len(batch))
|
||||
phone_padded = torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
|
||||
phone_padded = torch.FloatTensor(
|
||||
len(batch), max_phone_len, batch[0][2].shape[1]
|
||||
)
|
||||
phone_padded.zero_()
|
||||
sid = torch.LongTensor(len(batch))
|
||||
|
||||
@@ -355,7 +376,6 @@ class TextAudioCollate:
|
||||
|
||||
sid[i] = row[3]
|
||||
|
||||
|
||||
return (
|
||||
phone_padded,
|
||||
phone_lengths,
|
||||
@@ -363,9 +383,10 @@ class TextAudioCollate:
|
||||
spec_lengths,
|
||||
wave_padded,
|
||||
wave_lengths,
|
||||
sid
|
||||
sid,
|
||||
)
|
||||
|
||||
|
||||
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
"""
|
||||
Maintain similar input lengths in a batch.
|
||||
@@ -402,7 +423,7 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
if idx_bucket != -1:
|
||||
buckets[idx_bucket].append(i)
|
||||
|
||||
for i in range(len(buckets) - 1, -1, -1):#
|
||||
for i in range(len(buckets) - 1, -1, -1): #
|
||||
if len(buckets[i]) == 0:
|
||||
buckets.pop(i)
|
||||
self.boundaries.pop(i + 1)
|
||||
|
||||
Reference in New Issue
Block a user