Mirror of https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
Synced 2026-01-20 11:00:23 +00:00
chore(sync): merge dev into main (#1379)
* Optimize latency (#1259)
* add attribute: configs/config.py; Optimize latency: tools/rvc_for_realtime.py
* new file: assets/Synthesizer_inputs.pth
* fix: configs/config.py; fix: tools/rvc_for_realtime.py
* fix bug: infer/lib/infer_pack/models.py
* new file: assets/hubert_inputs.pth; new file: assets/rmvpe_inputs.pth; modified: configs/config.py; new features: infer/lib/rmvpe.py, tools/jit_export/__init__.py, tools/jit_export/get_hubert.py, tools/jit_export/get_rmvpe.py, tools/jit_export/get_synthesizer.py; optimize: tools/rvc_for_realtime.py
* optimize: tools/jit_export/get_synthesizer.py; fix bug: tools/jit_export/__init__.py
* Fixed a bug caused by using half on the CPU: infer/lib/rmvpe.py, tools/jit_export/__init__.py, tools/rvc_for_realtime.py; fixed CIRCULAR IMPORT: tools/jit_export/get_rmvpe.py, tools/jit_export/get_synthesizer.py
* Remove useless code: infer/lib/rmvpe.py
* Delete gui_v1 copy.py
* Delete .vscode/launch.json
* Delete jit_export_test.py
* Delete tools/rvc_for_realtime copy.py
* Delete configs/config.json
* Delete .gitignore
* Fix exceptions caused by switching inference devices: infer/lib/rmvpe.py, tools/jit_export/__init__.py, tools/rvc_for_realtime.py
* restore
* replace (you can undo this commit)
* remove debug_print
---------
Co-authored-by: Ftps <ftpsflandre@gmail.com>
* Fixed some bugs when exporting ONNX model (#1254)
* fix import (#1280)
* fix import
* lint
* 🎨 Sync locale (#1242)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* Fix jit load and import issue (#1282)
* fix jit model loading: infer/lib/rmvpe.py
* modified: assets/hubert/.gitignore; move file: assets/hubert_inputs.pth -> assets/hubert/hubert_inputs.pth; modified: assets/rmvpe/.gitignore; move file: assets/rmvpe_inputs.pth -> assets/rmvpe/rmvpe_inputs.pth; fix import: gui_v1.py
* feat(workflow): trigger on dev
* feat(workflow): add close-pr on non-dev branch
* Add input wav and delay time monitor for real-time gui (#1293)
* feat(workflow): trigger on dev
* feat(workflow): add close-pr on non-dev branch
* 🎨 Sync locale (#1289)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* feat: edit PR template
* add input wav and delay time monitor
---------
Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
* Optimize latency using scripted jit (#1291)
* feat(workflow): trigger on dev
* feat(workflow): add close-pr on non-dev branch
* 🎨 Sync locale (#1289)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* feat: edit PR template
* Optimize-latency-using-scripted: configs/config.py, infer/lib/infer_pack/attentions.py, infer/lib/infer_pack/commons.py, infer/lib/infer_pack/models.py, infer/lib/infer_pack/modules.py, infer/lib/jit/__init__.py, infer/lib/jit/get_hubert.py, infer/lib/jit/get_rmvpe.py, infer/lib/jit/get_synthesizer.py, infer/lib/rmvpe.py, tools/rvc_for_realtime.py
* modified: infer/lib/infer_pack/models.py
* fix some bug: configs/config.py, infer/lib/infer_pack/models.py, infer/lib/rmvpe.py
* Fixed abnormal reference of logger in multiprocessing: infer/modules/train/train.py
---------
Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* Format code (#1298)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* 🎨 Sync locale (#1299)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* feat: optimize actions
* feat(workflow): add sync dev
* feat: optimize actions
* feat: optimize actions
* feat: optimize actions
* feat: optimize actions
* feat: add jit options (#1303): Delete useless code: infer/lib/jit/get_synthesizer.py; Optimized code: tools/rvc_for_realtime.py
* Code refactor + re-design inference ui (#1304)
* Code refactor + re-design inference ui
* Fix tabname
* i18n jp
---------
Co-authored-by: Ftps <ftpsflandre@gmail.com>
* feat: optimize actions
* feat: optimize actions
* Update README & en_US locale file (#1309)
* critical: some bug fixes (#1322)
* JIT acceleration switch does not support hot update
* fix padding bug of rmvpe in torch-directml
* fix padding bug of rmvpe in torch-directml
* Fix STFT under torch_directml (#1330)
* chore(format): run black on dev (#1318)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* chore(i18n): sync locale on dev (#1317)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* feat: allow for tta to be passed to uvr (#1361)
* chore(format): run black on dev (#1373)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* Added script for automatically downloading all needed models at install (#1366)
* Delete modules.py
* Add files via upload
* Add files via upload
* Add files via upload
* Add files via upload
* chore(i18n): sync locale on dev (#1377)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* chore(format): run black on dev (#1376)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* Update IPEX library (#1362)
* Update IPEX library
* Update ipex index
* chore(format): run black on dev (#1378)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
---------
Co-authored-by: Chengjia Jiang <46401978+ChasonJiang@users.noreply.github.com>
Co-authored-by: Ftps <ftpsflandre@gmail.com>
Co-authored-by: shizuku_nia <102004222+ShizukuNia@users.noreply.github.com>
Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: yxlllc <33565655+yxlllc@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Co-authored-by: Blaise <133521603+blaise-tk@users.noreply.github.com>
Co-authored-by: Rice Cake <gak141808@gmail.com>
Co-authored-by: AWAS666 <33494149+AWAS666@users.noreply.github.com>
Co-authored-by: Dmitry <nda2911@yandex.ru>
Co-authored-by: Disty0 <47277141+Disty0@users.noreply.github.com>
Committed by: GitHub
Parent: fe166e7f3d
Commit: e9dd11bddb
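The diff below (infer/lib/infer_pack/models.py) adds type annotations and scripting hooks so the synthesizer can be compiled with torch.jit.script for the lower-latency real-time path described in the commit message. As a rough, hypothetical sketch of that workflow (not the repository's actual tools/jit_export or infer/lib/jit code), a module with a typed forward can be scripted and saved like this:

```python
# Hypothetical sketch: script a typed nn.Module and save it for fast reload.
# TinyNet is invented for illustration; it is not an RVC model.
import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

model = TinyNet().eval()
scripted = torch.jit.script(model)        # compiles the typed forward
torch.jit.save(scripted, "tiny_net.jit")  # reload later with torch.jit.load
```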
@@ -1,5 +1,6 @@
import math
import logging
from typing import Optional

logger = logging.getLogger(__name__)

@@ -28,25 +29,32 @@ class TextEncoder256(nn.Module):
p_dropout,
f0=True,
):
super().__init__()
super(TextEncoder256, self).__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.p_dropout = float(p_dropout)
self.emb_phone = nn.Linear(256, hidden_channels)
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
if f0 == True:
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
float(p_dropout),
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, phone, pitch, lengths):
if pitch == None:
def forward(
self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor
):
if pitch is None:
x = self.emb_phone(phone)
else:
x = self.emb_phone(phone) + self.emb_pitch(pitch)
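The hunk above replaces `pitch == None` with `pitch is None` and annotates `pitch` as `Optional[torch.Tensor]`. Under torch.jit.script an unannotated argument is assumed to be a Tensor, so passing None would not compile, and the `is None` check is what lets the compiler refine the Optional back to a Tensor inside the else branch. A minimal sketch of the same pattern (the class and tensors below are made up for illustration):

```python
# Sketch: Optional arguments and `is None` refinement under TorchScript.
from typing import Optional
import torch

class PitchAdder(torch.nn.Module):
    def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor]) -> torch.Tensor:
        if pitch is None:
            return phone
        return phone + pitch  # here the compiler knows pitch is a Tensor

m = torch.jit.script(PitchAdder())
print(m(torch.ones(2), None))
print(m(torch.ones(2), torch.ones(2)))
```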
@@ -75,25 +83,30 @@ class TextEncoder768(nn.Module):
p_dropout,
f0=True,
):
super().__init__()
super(TextEncoder768, self).__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.p_dropout = float(p_dropout)
self.emb_phone = nn.Linear(768, hidden_channels)
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
if f0 == True:
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
float(p_dropout),
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, phone, pitch, lengths):
if pitch == None:
def forward(self, phone: torch.Tensor, pitch: torch.Tensor, lengths: torch.Tensor):
if pitch is None:
x = self.emb_phone(phone)
else:
x = self.emb_phone(phone) + self.emb_pitch(pitch)
@@ -121,7 +134,7 @@ class ResidualCouplingBlock(nn.Module):
n_flows=4,
gin_channels=0,
):
super().__init__()
super(ResidualCouplingBlock, self).__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
@@ -145,19 +158,36 @@ class ResidualCouplingBlock(nn.Module):
)
self.flows.append(modules.Flip())

def forward(self, x, x_mask, g=None, reverse=False):
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
reverse: bool = False,
):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
for flow in self.flows[::-1]:
x, _ = flow.forward(x, x_mask, g=g, reverse=reverse)
return x

def remove_weight_norm(self):
for i in range(self.n_flows):
self.flows[i * 2].remove_weight_norm()

def __prepare_scriptable__(self):
for i in range(self.n_flows):
for hook in self.flows[i * 2]._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flows[i * 2])

return self


class PosteriorEncoder(nn.Module):
def __init__(
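The new `__prepare_scriptable__` methods strip the weight-norm forward pre-hooks right before compilation; as far as I can tell, recent PyTorch versions have torch.jit.script call this method on a module when it is defined and script whatever it returns. A small, self-contained sketch of the same pattern (the module below is invented for illustration, not an RVC class):

```python
# Sketch: fold weight_norm back into a plain parameter before scripting.
import torch
from torch.nn.utils import weight_norm, remove_weight_norm

class Cond(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # weight_norm registers a forward pre-hook on the conv
        self.conv = weight_norm(torch.nn.Conv1d(4, 4, 3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

    def __prepare_scriptable__(self):
        # called by torch.jit.script (when defined); drop the hook so the
        # module scripts as a plain Conv1d with the normalized weight baked in
        remove_weight_norm(self.conv)
        return self

scripted = torch.jit.script(Cond())
print(scripted(torch.randn(1, 4, 8)).shape)
```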
@@ -170,7 +200,7 @@ class PosteriorEncoder(nn.Module):
n_layers,
gin_channels=0,
):
super().__init__()
super(PosteriorEncoder, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
@@ -189,7 +219,9 @@ class PosteriorEncoder(nn.Module):
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, x, x_lengths, g=None):
def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
@@ -203,6 +235,15 @@ class PosteriorEncoder(nn.Module):
def remove_weight_norm(self):
self.enc.remove_weight_norm()

def __prepare_scriptable__(self):
for hook in self.enc._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc)
return self


class Generator(torch.nn.Module):
def __init__(
@@ -252,7 +293,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

def forward(self, x, g=None):
def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
@@ -273,6 +314,28 @@ class Generator(torch.nn.Module):

return x

def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)

for l in self.resblocks:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self

def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
@@ -293,7 +356,7 @@ class SineGen(torch.nn.Module):
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
segment is always sin(torch.pi) or cos(0)
"""

def __init__(
@@ -321,7 +384,7 @@ class SineGen(torch.nn.Module):
uv = uv.float()
return uv

def forward(self, f0, upp):
def forward(self, f0: torch.Tensor, upp: int):
"""sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
@@ -333,7 +396,7 @@ class SineGen(torch.nn.Module):
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
# fundamental component
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in np.arange(self.harmonic_num):
for idx in range(self.harmonic_num):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
idx + 2
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
@@ -347,12 +410,12 @@ class SineGen(torch.nn.Module):
tmp_over_one *= upp
tmp_over_one = F.interpolate(
tmp_over_one.transpose(2, 1),
scale_factor=upp,
scale_factor=float(upp),
mode="linear",
align_corners=True,
).transpose(2, 1)
rad_values = F.interpolate(
rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
).transpose(
2, 1
) #######
@@ -361,12 +424,12 @@ class SineGen(torch.nn.Module):
cumsum_shift = torch.zeros_like(rad_values)
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
sine_waves = torch.sin(
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
)
sine_waves = sine_waves * self.sine_amp
uv = self._f02uv(f0)
uv = F.interpolate(
uv.transpose(2, 1), scale_factor=upp, mode="nearest"
uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
).transpose(2, 1)
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
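The `float(upp)` casts and the move from `np.pi` to `torch.pi` keep the scripted graph free of NumPy and give `scale_factor` an unambiguous float, with `upp` now typed as an int. As a tiny illustration of the frame-to-sample upsampling SineGen performs with F.interpolate (the hop size and F0 values below are made up):

```python
# Sketch: repeat a per-frame F0 track `upp` times before accumulating phase.
import torch
import torch.nn.functional as F

upp = 4                                  # samples per frame (illustrative)
f0 = torch.tensor([[[100.0], [120.0]]])  # (batch, frames, 1)
f0_up = F.interpolate(
    f0.transpose(2, 1), scale_factor=float(upp), mode="nearest"
).transpose(2, 1)                        # (batch, frames * upp, 1)
print(f0_up.squeeze(-1))                 # 100.0 repeated 4x, then 120.0 repeated 4x
```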
@@ -414,18 +477,19 @@ class SourceModuleHnNSF(torch.nn.Module):
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
# self.ddtype:int = -1

def forward(self, x, upp=None):
if hasattr(self, "ddtype") == False:
self.ddtype = self.l_linear.weight.dtype
def forward(self, x: torch.Tensor, upp: int = 1):
# if self.ddtype ==-1:
#     self.ddtype = self.l_linear.weight.dtype
sine_wavs, uv, _ = self.l_sin_gen(x, upp)
# print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype)
# if self.is_half:
#     sine_wavs = sine_wavs.half()
#   sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x)))
# print(sine_wavs.dtype,self.ddtype)
if sine_wavs.dtype != self.ddtype:
sine_wavs = sine_wavs.to(self.ddtype)
# if sine_wavs.dtype != self.l_linear.weight.dtype:
sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
return sine_merge, None, None # noise, uv
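The rewritten forward drops the cached `ddtype` and simply casts the generated sine excitation to whatever dtype the merging linear layer currently holds, so half/full precision or device switches keep working. A sketch of the idea (the layer size and shapes are illustrative, not taken from the model):

```python
# Sketch: follow the layer's dtype instead of caching one at construction time.
import torch

l_linear = torch.nn.Linear(9, 1)           # e.g. harmonic_num + 1 inputs
sine_wavs = torch.randn(1, 16, 9, dtype=torch.float32)
sine_wavs = sine_wavs.to(dtype=l_linear.weight.dtype)  # fp16 or fp32, whichever the layer uses
sine_merge = torch.tanh(l_linear(sine_wavs))
print(sine_merge.dtype, sine_merge.shape)
```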
@@ -448,7 +512,7 @@ class GeneratorNSF(torch.nn.Module):
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)

self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
self.m_source = SourceModuleHnNSF(
sampling_rate=sr, harmonic_num=0, is_half=is_half
)
@@ -473,7 +537,7 @@ class GeneratorNSF(torch.nn.Module):
)
)
if i + 1 < len(upsample_rates):
stride_f0 = np.prod(upsample_rates[i + 1 :])
stride_f0 = math.prod(upsample_rates[i + 1 :])
self.noise_convs.append(
Conv1d(
1,
@@ -500,27 +564,36 @@ class GeneratorNSF(torch.nn.Module):
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

self.upp = np.prod(upsample_rates)
self.upp = math.prod(upsample_rates)

def forward(self, x, f0, g=None):
self.lrelu_slope = modules.LRELU_SLOPE

def forward(self, x, f0, g: Optional[torch.Tensor] = None):
har_source, noi_source, uv = self.m_source(f0, self.upp)
har_source = har_source.transpose(1, 2)
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)

for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
x_source = self.noise_convs[i](har_source)
x = x + x_source
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# torch.jit.script() does not support direct indexing of torch modules
# That's why I wrote this
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
if i < self.num_upsamples:
x = F.leaky_relu(x, self.lrelu_slope)
x = ups(x)
x_source = noise_convs(har_source)
x = x + x_source
xs: Optional[torch.Tensor] = None
l = [i * self.num_kernels + j for j in range(self.num_kernels)]
for j, resblock in enumerate(self.resblocks):
if j in l:
if xs is None:
xs = resblock(x)
else:
xs += resblock(x)
# This assertion cannot be ignored! \
# If ignored, it will cause torch.jit.script() compilation errors
assert isinstance(xs, torch.Tensor)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
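The loop above is rewritten because, as the new comment notes, torch.jit.script cannot index an nn.ModuleList with a runtime integer; the modules are instead iterated with enumerate/zip and filtered by index. A minimal sketch of that constraint (the Stack module and its sizes are invented for illustration):

```python
# Sketch: iterate a ModuleList under TorchScript instead of indexing it.
import torch

class Stack(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(3)])

    def forward(self, x: torch.Tensor, active: int) -> torch.Tensor:
        # self.blocks[active](x) would not compile; enumeration does
        for i, block in enumerate(self.blocks):
            if i == active:
                x = block(x)
        return x

m = torch.jit.script(Stack())
print(m(torch.randn(2, 4), 1).shape)
```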
@@ -532,6 +605,27 @@ class GeneratorNSF(torch.nn.Module):
for l in self.resblocks:
l.remove_weight_norm()

def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
for l in self.resblocks:
for hook in self.resblocks._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(l)
return self


sr2sr = {
"32k": 32000,
@@ -563,8 +657,8 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
sr,
**kwargs
):
super().__init__()
if type(sr) == type("strr"):
super(SynthesizerTrnMs256NSFsid, self).__init__()
if isinstance(sr, str):
sr = sr2sr[sr]
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@@ -573,7 +667,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.p_dropout = float(p_dropout)
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
@@ -591,7 +685,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
n_heads,
n_layers,
kernel_size,
p_dropout,
float(p_dropout),
)
self.dec = GeneratorNSF(
inter_channels,
@@ -630,8 +724,42 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
self.flow.remove_weight_norm()
self.enc_q.remove_weight_norm()

def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
return self

@torch.jit.ignore
def forward(
self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
pitch: torch.Tensor,
pitchf: torch.Tensor,
y: torch.Tensor,
y_lengths: torch.Tensor,
ds: Optional[torch.Tensor] = None,
): # 这里ds是id,[bs,1]
# print(1,pitch.shape)#[bs,t]
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
@@ -647,15 +775,25 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
o = self.dec(z_slice, pitchf, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
@torch.jit.export
def infer(
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
pitch: torch.Tensor,
nsff0: torch.Tensor,
sid: torch.Tensor,
rate: Optional[torch.Tensor] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
if rate:
head = int(z_p.shape[2] * rate)
z_p = z_p[:, :, -head:]
x_mask = x_mask[:, :, -head:]
nsff0 = nsff0[:, -head:]
if rate is not None:
assert isinstance(rate, torch.Tensor)
head = int(z_p.shape[2] * (1 - rate.item()))
z_p = z_p[:, :, head:]
x_mask = x_mask[:, :, head:]
nsff0 = nsff0[:, head:]
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(z * x_mask, nsff0, g=g)
return o, x_mask, (z, z_p, m_p, logs_p)
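The new `rate` handling passes `rate` as a tensor, computes `head = int(T * (1 - rate))`, and slices from `head`, which keeps the same trailing portion of the latent that the old negative-index form kept. A small worked check of that equivalence (shapes and the rate value are illustrative):

```python
# Worked example: both forms keep the last 25 of 100 frames for rate = 0.25.
import torch

z_p = torch.randn(1, 192, 100)             # (batch, channels, frames), illustrative
rate = torch.tensor(0.25)

head_old = int(z_p.shape[2] * rate)                 # 25 -> z_p[:, :, -25:]
old = z_p[:, :, -head_old:]

head_new = int(z_p.shape[2] * (1 - rate.item()))    # 75 -> z_p[:, :, 75:]
new = z_p[:, :, head_new:]

print(old.shape, new.shape, torch.equal(old, new))  # both (1, 192, 25), True
```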
@@ -684,8 +822,8 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
sr,
**kwargs
):
super().__init__()
if type(sr) == type("strr"):
super(SynthesizerTrnMs768NSFsid, self).__init__()
if isinstance(sr, str):
sr = sr2sr[sr]
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@@ -694,7 +832,7 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.p_dropout = float(p_dropout)
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
@@ -712,7 +850,7 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
n_heads,
n_layers,
kernel_size,
p_dropout,
float(p_dropout),
)
self.dec = GeneratorNSF(
inter_channels,
@@ -751,6 +889,33 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
self.flow.remove_weight_norm()
self.enc_q.remove_weight_norm()

def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
return self

@torch.jit.ignore
def forward(
self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
): # 这里ds是id,[bs,1]
@@ -768,15 +933,24 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
o = self.dec(z_slice, pitchf, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
@torch.jit.export
def infer(
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
pitch: torch.Tensor,
nsff0: torch.Tensor,
sid: torch.Tensor,
rate: Optional[torch.Tensor] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
if rate:
head = int(z_p.shape[2] * rate)
z_p = z_p[:, :, -head:]
x_mask = x_mask[:, :, -head:]
nsff0 = nsff0[:, -head:]
if rate is not None:
head = int(z_p.shape[2] * (1.0 - rate.item()))
z_p = z_p[:, :, head:]
x_mask = x_mask[:, :, head:]
nsff0 = nsff0[:, head:]
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(z * x_mask, nsff0, g=g)
return o, x_mask, (z, z_p, m_p, logs_p)
@@ -805,7 +979,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
sr=None,
**kwargs
):
super().__init__()
super(SynthesizerTrnMs256NSFsid_nono, self).__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
@@ -813,7 +987,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.p_dropout = float(p_dropout)
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
@@ -831,7 +1005,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
n_heads,
n_layers,
kernel_size,
p_dropout,
float(p_dropout),
f0=False,
)
self.dec = Generator(
@@ -869,6 +1043,33 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
self.flow.remove_weight_norm()
self.enc_q.remove_weight_norm()

def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
return self

@torch.jit.ignore
def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
@@ -880,14 +1081,22 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
o = self.dec(z_slice, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

def infer(self, phone, phone_lengths, sid, rate=None):
@torch.jit.export
def infer(
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
sid: torch.Tensor,
rate: Optional[torch.Tensor] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
if rate:
head = int(z_p.shape[2] * rate)
z_p = z_p[:, :, -head:]
x_mask = x_mask[:, :, -head:]
if rate is not None:
head = int(z_p.shape[2] * (1.0 - rate.item()))
z_p = z_p[:, :, head:]
x_mask = x_mask[:, :, head:]
nsff0 = nsff0[:, head:]
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(z * x_mask, g=g)
return o, x_mask, (z, z_p, m_p, logs_p)
@@ -916,7 +1125,7 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
sr=None,
**kwargs
):
super().__init__()
super(SynthesizerTrnMs768NSFsid_nono, self).__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
@@ -924,7 +1133,7 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.p_dropout = float(p_dropout)
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
@@ -942,7 +1151,7 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
n_heads,
n_layers,
kernel_size,
p_dropout,
float(p_dropout),
f0=False,
)
self.dec = Generator(
@@ -980,6 +1189,33 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
self.flow.remove_weight_norm()
self.enc_q.remove_weight_norm()

def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
return self

@torch.jit.ignore
def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
@@ -991,14 +1227,22 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
o = self.dec(z_slice, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

def infer(self, phone, phone_lengths, sid, rate=None):
@torch.jit.export
def infer(
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
sid: torch.Tensor,
rate: Optional[torch.Tensor] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
if rate:
head = int(z_p.shape[2] * rate)
z_p = z_p[:, :, -head:]
x_mask = x_mask[:, :, -head:]
if rate is not None:
head = int(z_p.shape[2] * (1.0 - rate.item()))
z_p = z_p[:, :, head:]
x_mask = x_mask[:, :, head:]
nsff0 = nsff0[:, head:]
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(z * x_mask, g=g)
return o, x_mask, (z, z_p, m_p, logs_p)