mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-01-20 11:00:23 +00:00
chore(sync): merge dev into main (#1379)
* Optimize latency (#1259) * add attribute: configs/config.py Optimize latency: tools/rvc_for_realtime.py * new file: assets/Synthesizer_inputs.pth * fix: configs/config.py fix: tools/rvc_for_realtime.py * fix bug: infer/lib/infer_pack/models.py * new file: assets/hubert_inputs.pth new file: assets/rmvpe_inputs.pth modified: configs/config.py new features: infer/lib/rmvpe.py new features: tools/jit_export/__init__.py new features: tools/jit_export/get_hubert.py new features: tools/jit_export/get_rmvpe.py new features: tools/jit_export/get_synthesizer.py optimize: tools/rvc_for_realtime.py * optimize: tools/jit_export/get_synthesizer.py fix bug: tools/jit_export/__init__.py * Fixed a bug caused by using half on the CPU: infer/lib/rmvpe.py Fixed a bug caused by using half on the CPU: tools/jit_export/__init__.py Fixed CIRCULAR IMPORT: tools/jit_export/get_rmvpe.py Fixed CIRCULAR IMPORT: tools/jit_export/get_synthesizer.py Fixed a bug caused by using half on the CPU: tools/rvc_for_realtime.py * Remove useless code: infer/lib/rmvpe.py * Delete gui_v1 copy.py * Delete .vscode/launch.json * Delete jit_export_test.py * Delete tools/rvc_for_realtime copy.py * Delete configs/config.json * Delete .gitignore * Fix exceptions caused by switching inference devices: infer/lib/rmvpe.py Fix exceptions caused by switching inference devices: tools/jit_export/__init__.py Fix exceptions caused by switching inference devices: tools/rvc_for_realtime.py * restore * replace(you can undo this commit) * remove debug_print --------- Co-authored-by: Ftps <ftpsflandre@gmail.com> * Fixed some bugs when exporting ONNX model (#1254) * fix import (#1280) * fix import * lint * 🎨 同步 locale (#1242) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Fix jit load and import issue (#1282) * fix jit model loading : infer/lib/rmvpe.py * modified: assets/hubert/.gitignore move file: assets/hubert_inputs.pth -> assets/hubert/hubert_inputs.pth modified: assets/rmvpe/.gitignore move file: assets/rmvpe_inputs.pth -> assets/rmvpe/rmvpe_inputs.pth fix import: gui_v1.py * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * Add input wav and delay time monitor for real-time gui (#1293) * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * 🎨 同步 locale (#1289) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: edit PR template * add input wav and delay time monitor --------- Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> * Optimize latency using scripted jit (#1291) * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * 🎨 同步 locale (#1289) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: edit PR template * Optimize-latency-using-scripted: configs/config.py Optimize-latency-using-scripted: infer/lib/infer_pack/attentions.py Optimize-latency-using-scripted: infer/lib/infer_pack/commons.py Optimize-latency-using-scripted: infer/lib/infer_pack/models.py Optimize-latency-using-scripted: infer/lib/infer_pack/modules.py Optimize-latency-using-scripted: infer/lib/jit/__init__.py Optimize-latency-using-scripted: infer/lib/jit/get_hubert.py Optimize-latency-using-scripted: infer/lib/jit/get_rmvpe.py Optimize-latency-using-scripted: infer/lib/jit/get_synthesizer.py Optimize-latency-using-scripted: infer/lib/rmvpe.py Optimize-latency-using-scripted: tools/rvc_for_realtime.py * modified: infer/lib/infer_pack/models.py * fix some bug: configs/config.py fix some bug: infer/lib/infer_pack/models.py fix some bug: infer/lib/rmvpe.py * Fixed abnormal reference of logger in multiprocessing: infer/modules/train/train.py --------- Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Format code (#1298) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * 🎨 同步 locale (#1299) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: optimize actions * feat(workflow): add sync dev * feat: optimize actions * feat: optimize actions * feat: optimize actions * feat: optimize actions * feat: add jit options (#1303) Delete useless code: infer/lib/jit/get_synthesizer.py Optimized code: tools/rvc_for_realtime.py * Code refactor + re-design inference ui (#1304) * Code refacor + re-design inference ui * Fix tabname * i18n jp --------- Co-authored-by: Ftps <ftpsflandre@gmail.com> * feat: optimize actions * feat: optimize actions * Update README & en_US locale file (#1309) * critical: some bug fixes (#1322) * JIT acceleration switch does not support hot update * fix padding bug of rmvpe in torch-directml * fix padding bug of rmvpe in torch-directml * Fix STFT under torch_directml (#1330) * chore(format): run black on dev (#1318) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * chore(i18n): sync locale on dev (#1317) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: allow for tta to be passed to uvr (#1361) * chore(format): run black on dev (#1373) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Added script for automatically download all needed models at install (#1366) * Delete modules.py * Add files via upload * Add files via upload * Add files via upload * Add files via upload * chore(i18n): sync locale on dev (#1377) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * chore(format): run black on dev (#1376) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Update IPEX library (#1362) * Update IPEX library * Update ipex index * chore(format): run black on dev (#1378) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Chengjia Jiang <46401978+ChasonJiang@users.noreply.github.com> Co-authored-by: Ftps <ftpsflandre@gmail.com> Co-authored-by: shizuku_nia <102004222+ShizukuNia@users.noreply.github.com> Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: yxlllc <33565655+yxlllc@users.noreply.github.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Co-authored-by: Blaise <133521603+blaise-tk@users.noreply.github.com> Co-authored-by: Rice Cake <gak141808@gmail.com> Co-authored-by: AWAS666 <33494149+AWAS666@users.noreply.github.com> Co-authored-by: Dmitry <nda2911@yandex.ru> Co-authored-by: Disty0 <47277141+Disty0@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
fe166e7f3d
commit
e9dd11bddb
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -22,11 +23,11 @@ class Encoder(nn.Module):
|
||||
window_size=10,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
super(Encoder, self).__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.n_layers = int(n_layers)
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.window_size = window_size
|
||||
@@ -61,14 +62,17 @@ class Encoder(nn.Module):
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
zippep = zip(
|
||||
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
|
||||
)
|
||||
for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zippep:
|
||||
y = attn_layers(x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
x = norm_layers_1(x + y)
|
||||
|
||||
y = self.ffn_layers[i](x, x_mask)
|
||||
y = ffn_layers(x, x_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
x = norm_layers_2(x + y)
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
@@ -86,7 +90,7 @@ class Decoder(nn.Module):
|
||||
proximal_init=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
super(Decoder, self).__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
@@ -172,7 +176,7 @@ class MultiHeadAttention(nn.Module):
|
||||
proximal_bias=False,
|
||||
proximal_init=False,
|
||||
):
|
||||
super().__init__()
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
assert channels % n_heads == 0
|
||||
|
||||
self.channels = channels
|
||||
@@ -213,19 +217,28 @@ class MultiHeadAttention(nn.Module):
|
||||
self.conv_k.weight.copy_(self.conv_q.weight)
|
||||
self.conv_k.bias.copy_(self.conv_q.bias)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
def forward(
|
||||
self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
|
||||
):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
|
||||
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
x, _ = self.attention(q, k, v, mask=attn_mask)
|
||||
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask=None):
|
||||
def attention(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||
b, d, t_s, t_t = (*key.size(), query.size(2))
|
||||
b, d, t_s = key.size()
|
||||
t_t = query.size(2)
|
||||
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
||||
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
@@ -292,16 +305,17 @@ class MultiHeadAttention(nn.Module):
|
||||
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
||||
return ret
|
||||
|
||||
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||
def _get_relative_embeddings(self, relative_embeddings, length: int):
|
||||
max_relative_position = 2 * self.window_size + 1
|
||||
# Pad first before slice to avoid using cond ops.
|
||||
pad_length = max(length - (self.window_size + 1), 0)
|
||||
pad_length: int = max(length - (self.window_size + 1), 0)
|
||||
slice_start_position = max((self.window_size + 1) - length, 0)
|
||||
slice_end_position = slice_start_position + 2 * length - 1
|
||||
if pad_length > 0:
|
||||
padded_relative_embeddings = F.pad(
|
||||
relative_embeddings,
|
||||
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
||||
# commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
||||
[0, 0, pad_length, pad_length, 0, 0],
|
||||
)
|
||||
else:
|
||||
padded_relative_embeddings = relative_embeddings
|
||||
@@ -317,12 +331,18 @@ class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# Concat columns of pad to shift from relative to absolute indexing.
|
||||
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
||||
x = F.pad(
|
||||
x,
|
||||
# commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
|
||||
[0, 1, 0, 0, 0, 0, 0, 0],
|
||||
)
|
||||
|
||||
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
x_flat = F.pad(
|
||||
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
||||
x_flat,
|
||||
# commons.convert_pad_shape([[0, 0], [0, 0], [0, int(length) - 1]])
|
||||
[0, int(length) - 1, 0, 0, 0, 0],
|
||||
)
|
||||
|
||||
# Reshape and slice out the padded elements.
|
||||
@@ -339,15 +359,21 @@ class MultiHeadAttention(nn.Module):
|
||||
batch, heads, length, _ = x.size()
|
||||
# padd along column
|
||||
x = F.pad(
|
||||
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
||||
x,
|
||||
# commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, int(length) - 1]])
|
||||
[0, int(length) - 1, 0, 0, 0, 0, 0, 0],
|
||||
)
|
||||
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
||||
x_flat = x.view([batch, heads, int(length**2) + int(length * (length - 1))])
|
||||
# add 0's in the beginning that will skew the elements after reshape
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||
x_flat = F.pad(
|
||||
x_flat,
|
||||
# commons.convert_pad_shape([[0, 0], [0, 0], [int(length), 0]])
|
||||
[length, 0, 0, 0, 0, 0],
|
||||
)
|
||||
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
||||
return x_final
|
||||
|
||||
def _attention_bias_proximal(self, length):
|
||||
def _attention_bias_proximal(self, length: int):
|
||||
"""Bias for self-attention to encourage attention to close positions.
|
||||
Args:
|
||||
length: an integer scalar.
|
||||
@@ -367,10 +393,10 @@ class FFN(nn.Module):
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
p_dropout=0.0,
|
||||
activation=None,
|
||||
activation: str = None,
|
||||
causal=False,
|
||||
):
|
||||
super().__init__()
|
||||
super(FFN, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.filter_channels = filter_channels
|
||||
@@ -378,40 +404,56 @@ class FFN(nn.Module):
|
||||
self.p_dropout = p_dropout
|
||||
self.activation = activation
|
||||
self.causal = causal
|
||||
|
||||
if causal:
|
||||
self.padding = self._causal_padding
|
||||
else:
|
||||
self.padding = self._same_padding
|
||||
self.is_activation = True if activation == "gelu" else False
|
||||
# if causal:
|
||||
# self.padding = self._causal_padding
|
||||
# else:
|
||||
# self.padding = self._same_padding
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(self.padding(x * x_mask))
|
||||
if self.activation == "gelu":
|
||||
def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
|
||||
if self.causal:
|
||||
padding = self._causal_padding(x * x_mask)
|
||||
else:
|
||||
padding = self._same_padding(x * x_mask)
|
||||
return padding
|
||||
|
||||
def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
|
||||
x = self.conv_1(self.padding(x, x_mask))
|
||||
if self.is_activation:
|
||||
x = x * torch.sigmoid(1.702 * x)
|
||||
else:
|
||||
x = torch.relu(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(self.padding(x * x_mask))
|
||||
|
||||
x = self.conv_2(self.padding(x, x_mask))
|
||||
return x * x_mask
|
||||
|
||||
def _causal_padding(self, x):
|
||||
if self.kernel_size == 1:
|
||||
return x
|
||||
pad_l = self.kernel_size - 1
|
||||
pad_r = 0
|
||||
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(x, commons.convert_pad_shape(padding))
|
||||
pad_l: int = self.kernel_size - 1
|
||||
pad_r: int = 0
|
||||
# padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(
|
||||
x,
|
||||
# commons.convert_pad_shape(padding)
|
||||
[pad_l, pad_r, 0, 0, 0, 0],
|
||||
)
|
||||
return x
|
||||
|
||||
def _same_padding(self, x):
|
||||
if self.kernel_size == 1:
|
||||
return x
|
||||
pad_l = (self.kernel_size - 1) // 2
|
||||
pad_r = self.kernel_size // 2
|
||||
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(x, commons.convert_pad_shape(padding))
|
||||
pad_l: int = (self.kernel_size - 1) // 2
|
||||
pad_r: int = self.kernel_size // 2
|
||||
# padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(
|
||||
x,
|
||||
# commons.convert_pad_shape(padding)
|
||||
[pad_l, pad_r, 0, 0, 0, 0],
|
||||
)
|
||||
return x
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import List, Optional
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
@@ -16,10 +17,10 @@ def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
# def convert_pad_shape(pad_shape):
|
||||
# l = pad_shape[::-1]
|
||||
# pad_shape = [item for sublist in l for item in sublist]
|
||||
# return pad_shape
|
||||
|
||||
|
||||
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
||||
@@ -113,10 +114,14 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
return acts
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
# def convert_pad_shape(pad_shape):
|
||||
# l = pad_shape[::-1]
|
||||
# pad_shape = [item for sublist in l for item in sublist]
|
||||
# return pad_shape
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
|
||||
return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()
|
||||
|
||||
|
||||
def shift_1d(x):
|
||||
@@ -124,7 +129,7 @@ def shift_1d(x):
|
||||
return x
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,25 +29,32 @@ class TextEncoder256(nn.Module):
|
||||
p_dropout,
|
||||
f0=True,
|
||||
):
|
||||
super().__init__()
|
||||
super(TextEncoder256, self).__init__()
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.emb_phone = nn.Linear(256, hidden_channels)
|
||||
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
|
||||
if f0 == True:
|
||||
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
float(p_dropout),
|
||||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, phone, pitch, lengths):
|
||||
if pitch == None:
|
||||
def forward(
|
||||
self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor
|
||||
):
|
||||
if pitch is None:
|
||||
x = self.emb_phone(phone)
|
||||
else:
|
||||
x = self.emb_phone(phone) + self.emb_pitch(pitch)
|
||||
@@ -75,25 +83,30 @@ class TextEncoder768(nn.Module):
|
||||
p_dropout,
|
||||
f0=True,
|
||||
):
|
||||
super().__init__()
|
||||
super(TextEncoder768, self).__init__()
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.emb_phone = nn.Linear(768, hidden_channels)
|
||||
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
|
||||
if f0 == True:
|
||||
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
float(p_dropout),
|
||||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, phone, pitch, lengths):
|
||||
if pitch == None:
|
||||
def forward(self, phone: torch.Tensor, pitch: torch.Tensor, lengths: torch.Tensor):
|
||||
if pitch is None:
|
||||
x = self.emb_phone(phone)
|
||||
else:
|
||||
x = self.emb_phone(phone) + self.emb_pitch(pitch)
|
||||
@@ -121,7 +134,7 @@ class ResidualCouplingBlock(nn.Module):
|
||||
n_flows=4,
|
||||
gin_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
super(ResidualCouplingBlock, self).__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
@@ -145,19 +158,36 @@ class ResidualCouplingBlock(nn.Module):
|
||||
)
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
reverse: bool = False,
|
||||
):
|
||||
if not reverse:
|
||||
for flow in self.flows:
|
||||
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||
else:
|
||||
for flow in reversed(self.flows):
|
||||
x = flow(x, x_mask, g=g, reverse=reverse)
|
||||
for flow in self.flows[::-1]:
|
||||
x, _ = flow.forward(x, x_mask, g=g, reverse=reverse)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for i in range(self.n_flows):
|
||||
self.flows[i * 2].remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for i in range(self.n_flows):
|
||||
for hook in self.flows[i * 2]._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flows[i * 2])
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PosteriorEncoder(nn.Module):
|
||||
def __init__(
|
||||
@@ -170,7 +200,7 @@ class PosteriorEncoder(nn.Module):
|
||||
n_layers,
|
||||
gin_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
super(PosteriorEncoder, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -189,7 +219,9 @@ class PosteriorEncoder(nn.Module):
|
||||
)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||
):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
@@ -203,6 +235,15 @@ class PosteriorEncoder(nn.Module):
|
||||
def remove_weight_norm(self):
|
||||
self.enc.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.enc._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc)
|
||||
return self
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(
|
||||
@@ -252,7 +293,7 @@ class Generator(torch.nn.Module):
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g=None):
|
||||
def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
@@ -273,6 +314,28 @@ class Generator(torch.nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.ups:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
|
||||
for l in self.resblocks:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
return self
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
@@ -293,7 +356,7 @@ class SineGen(torch.nn.Module):
|
||||
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||
segment is always sin(np.pi) or cos(0)
|
||||
segment is always sin(torch.pi) or cos(0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -321,7 +384,7 @@ class SineGen(torch.nn.Module):
|
||||
uv = uv.float()
|
||||
return uv
|
||||
|
||||
def forward(self, f0, upp):
|
||||
def forward(self, f0: torch.Tensor, upp: int):
|
||||
"""sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, length, dim=1)
|
||||
f0 for unvoiced steps should be 0
|
||||
@@ -333,7 +396,7 @@ class SineGen(torch.nn.Module):
|
||||
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
||||
# fundamental component
|
||||
f0_buf[:, :, 0] = f0[:, :, 0]
|
||||
for idx in np.arange(self.harmonic_num):
|
||||
for idx in range(self.harmonic_num):
|
||||
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
|
||||
idx + 2
|
||||
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
|
||||
@@ -347,12 +410,12 @@ class SineGen(torch.nn.Module):
|
||||
tmp_over_one *= upp
|
||||
tmp_over_one = F.interpolate(
|
||||
tmp_over_one.transpose(2, 1),
|
||||
scale_factor=upp,
|
||||
scale_factor=float(upp),
|
||||
mode="linear",
|
||||
align_corners=True,
|
||||
).transpose(2, 1)
|
||||
rad_values = F.interpolate(
|
||||
rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
|
||||
rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
|
||||
).transpose(
|
||||
2, 1
|
||||
) #######
|
||||
@@ -361,12 +424,12 @@ class SineGen(torch.nn.Module):
|
||||
cumsum_shift = torch.zeros_like(rad_values)
|
||||
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
||||
sine_waves = torch.sin(
|
||||
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
|
||||
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
|
||||
)
|
||||
sine_waves = sine_waves * self.sine_amp
|
||||
uv = self._f02uv(f0)
|
||||
uv = F.interpolate(
|
||||
uv.transpose(2, 1), scale_factor=upp, mode="nearest"
|
||||
uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
|
||||
).transpose(2, 1)
|
||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
@@ -414,18 +477,19 @@ class SourceModuleHnNSF(torch.nn.Module):
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = torch.nn.Tanh()
|
||||
# self.ddtype:int = -1
|
||||
|
||||
def forward(self, x, upp=None):
|
||||
if hasattr(self, "ddtype") == False:
|
||||
self.ddtype = self.l_linear.weight.dtype
|
||||
def forward(self, x: torch.Tensor, upp: int = 1):
|
||||
# if self.ddtype ==-1:
|
||||
# self.ddtype = self.l_linear.weight.dtype
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x, upp)
|
||||
# print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype)
|
||||
# if self.is_half:
|
||||
# sine_wavs = sine_wavs.half()
|
||||
# sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x)))
|
||||
# print(sine_wavs.dtype,self.ddtype)
|
||||
if sine_wavs.dtype != self.ddtype:
|
||||
sine_wavs = sine_wavs.to(self.ddtype)
|
||||
# if sine_wavs.dtype != self.l_linear.weight.dtype:
|
||||
sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
return sine_merge, None, None # noise, uv
|
||||
|
||||
@@ -448,7 +512,7 @@ class GeneratorNSF(torch.nn.Module):
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=sr, harmonic_num=0, is_half=is_half
|
||||
)
|
||||
@@ -473,7 +537,7 @@ class GeneratorNSF(torch.nn.Module):
|
||||
)
|
||||
)
|
||||
if i + 1 < len(upsample_rates):
|
||||
stride_f0 = np.prod(upsample_rates[i + 1 :])
|
||||
stride_f0 = math.prod(upsample_rates[i + 1 :])
|
||||
self.noise_convs.append(
|
||||
Conv1d(
|
||||
1,
|
||||
@@ -500,27 +564,36 @@ class GeneratorNSF(torch.nn.Module):
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
self.upp = np.prod(upsample_rates)
|
||||
self.upp = math.prod(upsample_rates)
|
||||
|
||||
def forward(self, x, f0, g=None):
|
||||
self.lrelu_slope = modules.LRELU_SLOPE
|
||||
|
||||
def forward(self, x, f0, g: Optional[torch.Tensor] = None):
|
||||
har_source, noi_source, uv = self.m_source(f0, self.upp)
|
||||
har_source = har_source.transpose(1, 2)
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
x_source = self.noise_convs[i](har_source)
|
||||
x = x + x_source
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
# torch.jit.script() does not support direct indexing of torch modules
|
||||
# That's why I wrote this
|
||||
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
|
||||
if i < self.num_upsamples:
|
||||
x = F.leaky_relu(x, self.lrelu_slope)
|
||||
x = ups(x)
|
||||
x_source = noise_convs(har_source)
|
||||
x = x + x_source
|
||||
xs: Optional[torch.Tensor] = None
|
||||
l = [i * self.num_kernels + j for j in range(self.num_kernels)]
|
||||
for j, resblock in enumerate(self.resblocks):
|
||||
if j in l:
|
||||
if xs is None:
|
||||
xs = resblock(x)
|
||||
else:
|
||||
xs += resblock(x)
|
||||
# This assertion cannot be ignored! \
|
||||
# If ignored, it will cause torch.jit.script() compilation errors
|
||||
assert isinstance(xs, torch.Tensor)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
@@ -532,6 +605,27 @@ class GeneratorNSF(torch.nn.Module):
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.ups:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
for hook in self.resblocks._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
return self
|
||||
|
||||
|
||||
sr2sr = {
|
||||
"32k": 32000,
|
||||
@@ -563,8 +657,8 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
|
||||
sr,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
if type(sr) == type("strr"):
|
||||
super(SynthesizerTrnMs256NSFsid, self).__init__()
|
||||
if isinstance(sr, str):
|
||||
sr = sr2sr[sr]
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
@@ -573,7 +667,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
@@ -591,7 +685,7 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
float(p_dropout),
|
||||
)
|
||||
self.dec = GeneratorNSF(
|
||||
inter_channels,
|
||||
@@ -630,8 +724,42 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
|
||||
self.flow.remove_weight_norm()
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.dec._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.dec)
|
||||
for hook in self.flow._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flow)
|
||||
if hasattr(self, "enc_q"):
|
||||
for hook in self.enc_q._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||||
return self
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(
|
||||
self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
pitchf: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
y_lengths: torch.Tensor,
|
||||
ds: Optional[torch.Tensor] = None,
|
||||
): # 这里ds是id,[bs,1]
|
||||
# print(1,pitch.shape)#[bs,t]
|
||||
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
|
||||
@@ -647,15 +775,25 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
|
||||
o = self.dec(z_slice, pitchf, g=g)
|
||||
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
nsff0: torch.Tensor,
|
||||
sid: torch.Tensor,
|
||||
rate: Optional[torch.Tensor] = None,
|
||||
):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
if rate:
|
||||
head = int(z_p.shape[2] * rate)
|
||||
z_p = z_p[:, :, -head:]
|
||||
x_mask = x_mask[:, :, -head:]
|
||||
nsff0 = nsff0[:, -head:]
|
||||
if rate is not None:
|
||||
assert isinstance(rate, torch.Tensor)
|
||||
head = int(z_p.shape[2] * (1 - rate.item()))
|
||||
z_p = z_p[:, :, head:]
|
||||
x_mask = x_mask[:, :, head:]
|
||||
nsff0 = nsff0[:, head:]
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
o = self.dec(z * x_mask, nsff0, g=g)
|
||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||
@@ -684,8 +822,8 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
|
||||
sr,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
if type(sr) == type("strr"):
|
||||
super(SynthesizerTrnMs768NSFsid, self).__init__()
|
||||
if isinstance(sr, str):
|
||||
sr = sr2sr[sr]
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
@@ -694,7 +832,7 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
@@ -712,7 +850,7 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
float(p_dropout),
|
||||
)
|
||||
self.dec = GeneratorNSF(
|
||||
inter_channels,
|
||||
@@ -751,6 +889,33 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
|
||||
self.flow.remove_weight_norm()
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.dec._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.dec)
|
||||
for hook in self.flow._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flow)
|
||||
if hasattr(self, "enc_q"):
|
||||
for hook in self.enc_q._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||||
return self
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(
|
||||
self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
|
||||
): # 这里ds是id,[bs,1]
|
||||
@@ -768,15 +933,24 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
|
||||
o = self.dec(z_slice, pitchf, g=g)
|
||||
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
nsff0: torch.Tensor,
|
||||
sid: torch.Tensor,
|
||||
rate: Optional[torch.Tensor] = None,
|
||||
):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
if rate:
|
||||
head = int(z_p.shape[2] * rate)
|
||||
z_p = z_p[:, :, -head:]
|
||||
x_mask = x_mask[:, :, -head:]
|
||||
nsff0 = nsff0[:, -head:]
|
||||
if rate is not None:
|
||||
head = int(z_p.shape[2] * (1.0 - rate.item()))
|
||||
z_p = z_p[:, :, head:]
|
||||
x_mask = x_mask[:, :, head:]
|
||||
nsff0 = nsff0[:, head:]
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
o = self.dec(z * x_mask, nsff0, g=g)
|
||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||
@@ -805,7 +979,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
||||
sr=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
super(SynthesizerTrnMs256NSFsid_nono, self).__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -813,7 +987,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
@@ -831,7 +1005,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
float(p_dropout),
|
||||
f0=False,
|
||||
)
|
||||
self.dec = Generator(
|
||||
@@ -869,6 +1043,33 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
||||
self.flow.remove_weight_norm()
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.dec._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.dec)
|
||||
for hook in self.flow._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flow)
|
||||
if hasattr(self, "enc_q"):
|
||||
for hook in self.enc_q._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||||
return self
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
|
||||
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
||||
@@ -880,14 +1081,22 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
||||
o = self.dec(z_slice, g=g)
|
||||
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
def infer(self, phone, phone_lengths, sid, rate=None):
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
sid: torch.Tensor,
|
||||
rate: Optional[torch.Tensor] = None,
|
||||
):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
if rate:
|
||||
head = int(z_p.shape[2] * rate)
|
||||
z_p = z_p[:, :, -head:]
|
||||
x_mask = x_mask[:, :, -head:]
|
||||
if rate is not None:
|
||||
head = int(z_p.shape[2] * (1.0 - rate.item()))
|
||||
z_p = z_p[:, :, head:]
|
||||
x_mask = x_mask[:, :, head:]
|
||||
nsff0 = nsff0[:, head:]
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
o = self.dec(z * x_mask, g=g)
|
||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||
@@ -916,7 +1125,7 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
|
||||
sr=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
super(self, SynthesizerTrnMs768NSFsid_nono).__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -924,7 +1133,7 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
@@ -942,7 +1151,7 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
float(p_dropout),
|
||||
f0=False,
|
||||
)
|
||||
self.dec = Generator(
|
||||
@@ -980,6 +1189,33 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
|
||||
self.flow.remove_weight_norm()
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.dec._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.dec)
|
||||
for hook in self.flow._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flow)
|
||||
if hasattr(self, "enc_q"):
|
||||
for hook in self.enc_q._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||||
return self
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
|
||||
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
||||
@@ -991,14 +1227,22 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
|
||||
o = self.dec(z_slice, g=g)
|
||||
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
def infer(self, phone, phone_lengths, sid, rate=None):
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
sid: torch.Tensor,
|
||||
rate: Optional[torch.Tensor] = None,
|
||||
):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
if rate:
|
||||
head = int(z_p.shape[2] * rate)
|
||||
z_p = z_p[:, :, -head:]
|
||||
x_mask = x_mask[:, :, -head:]
|
||||
if rate is not None:
|
||||
head = int(z_p.shape[2] * (1.0 - rate.item()))
|
||||
z_p = z_p[:, :, head:]
|
||||
x_mask = x_mask[:, :, head:]
|
||||
nsff0 = nsff0[:, head:]
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
o = self.dec(z * x_mask, g=g)
|
||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
@@ -551,7 +551,7 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
|
||||
gin_channels,
|
||||
sr,
|
||||
version,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if type(sr) == type("strr"):
|
||||
@@ -621,10 +621,7 @@ class SynthesizerTrnMsNSFsidM(nn.Module):
|
||||
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
||||
self.speaker_map = None
|
||||
logger.debug(
|
||||
"gin_channels: "
|
||||
+ gin_channels
|
||||
+ ", self.spk_embed_dim: "
|
||||
+ self.spk_embed_dim
|
||||
f"gin_channels: {gin_channels}, self.spk_embed_dim: {self.spk_embed_dim}"
|
||||
)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
@@ -18,7 +19,7 @@ LRELU_SLOPE = 0.1
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
super(LayerNorm, self).__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
@@ -41,13 +42,13 @@ class ConvReluNorm(nn.Module):
|
||||
n_layers,
|
||||
p_dropout,
|
||||
):
|
||||
super().__init__()
|
||||
super(ConvReluNorm, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
assert n_layers > 1, "Number of layers should be larger than 0."
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
@@ -58,7 +59,7 @@ class ConvReluNorm(nn.Module):
|
||||
)
|
||||
)
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
||||
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(float(p_dropout)))
|
||||
for _ in range(n_layers - 1):
|
||||
self.conv_layers.append(
|
||||
nn.Conv1d(
|
||||
@@ -89,13 +90,13 @@ class DDSConv(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
||||
super().__init__()
|
||||
super(DDSConv, self).__init__()
|
||||
self.channels = channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.drop = nn.Dropout(float(p_dropout))
|
||||
self.convs_sep = nn.ModuleList()
|
||||
self.convs_1x1 = nn.ModuleList()
|
||||
self.norms_1 = nn.ModuleList()
|
||||
@@ -117,7 +118,7 @@ class DDSConv(nn.Module):
|
||||
self.norms_1.append(LayerNorm(channels))
|
||||
self.norms_2.append(LayerNorm(channels))
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
def forward(self, x, x_mask, g: Optional[torch.Tensor] = None):
|
||||
if g is not None:
|
||||
x = x + g
|
||||
for i in range(self.n_layers):
|
||||
@@ -149,11 +150,11 @@ class WN(torch.nn.Module):
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.gin_channels = gin_channels
|
||||
self.p_dropout = p_dropout
|
||||
self.p_dropout = float(p_dropout)
|
||||
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.drop = nn.Dropout(float(p_dropout))
|
||||
|
||||
if gin_channels != 0:
|
||||
cond_layer = torch.nn.Conv1d(
|
||||
@@ -184,15 +185,19 @@ class WN(torch.nn.Module):
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def forward(self, x, x_mask, g=None, **kwargs):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||
):
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
for i in range(self.n_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
for i, (in_layer, res_skip_layer) in enumerate(
|
||||
zip(self.in_layers, self.res_skip_layers)
|
||||
):
|
||||
x_in = in_layer(x)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
||||
@@ -202,7 +207,7 @@ class WN(torch.nn.Module):
|
||||
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
||||
acts = self.drop(acts)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
res_skip_acts = res_skip_layer(acts)
|
||||
if i < self.n_layers - 1:
|
||||
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
||||
x = (x + res_acts) * x_mask
|
||||
@@ -219,6 +224,30 @@ class WN(torch.nn.Module):
|
||||
for l in self.res_skip_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
if self.gin_channels != 0:
|
||||
for hook in self.cond_layer._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
for l in self.in_layers:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
return self
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
@@ -294,14 +323,15 @@ class ResBlock1(torch.nn.Module):
|
||||
]
|
||||
)
|
||||
self.convs2.apply(init_weights)
|
||||
self.lrelu_slope = LRELU_SLOPE
|
||||
|
||||
def forward(self, x, x_mask=None):
|
||||
def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = F.leaky_relu(x, self.lrelu_slope)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = F.leaky_relu(xt, self.lrelu_slope)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c2(xt)
|
||||
@@ -316,6 +346,23 @@ class ResBlock1(torch.nn.Module):
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.convs1:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
return self
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
@@ -345,10 +392,11 @@ class ResBlock2(torch.nn.Module):
|
||||
]
|
||||
)
|
||||
self.convs.apply(init_weights)
|
||||
self.lrelu_slope = LRELU_SLOPE
|
||||
|
||||
def forward(self, x, x_mask=None):
|
||||
def forward(self, x, x_mask: Optional[torch.Tensor] = None):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = F.leaky_relu(x, self.lrelu_slope)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c(xt)
|
||||
@@ -361,9 +409,25 @@ class ResBlock2(torch.nn.Module):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for l in self.convs:
|
||||
for hook in l._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
return self
|
||||
|
||||
|
||||
class Log(nn.Module):
|
||||
def forward(self, x, x_mask, reverse=False, **kwargs):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
reverse: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if not reverse:
|
||||
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
||||
logdet = torch.sum(-y, [1, 2])
|
||||
@@ -374,18 +438,27 @@ class Log(nn.Module):
|
||||
|
||||
|
||||
class Flip(nn.Module):
|
||||
def forward(self, x, *args, reverse=False, **kwargs):
|
||||
# torch.jit.script() Compiled functions \
|
||||
# can't take variable number of arguments or \
|
||||
# use keyword-only arguments with defaults
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
reverse: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
x = torch.flip(x, [1])
|
||||
if not reverse:
|
||||
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
||||
return x, torch.zeros([1], device=x.device)
|
||||
|
||||
|
||||
class ElementwiseAffine(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
super(ElementwiseAffine, self).__init__()
|
||||
self.channels = channels
|
||||
self.m = nn.Parameter(torch.zeros(channels, 1))
|
||||
self.logs = nn.Parameter(torch.zeros(channels, 1))
|
||||
@@ -414,7 +487,7 @@ class ResidualCouplingLayer(nn.Module):
|
||||
mean_only=False,
|
||||
):
|
||||
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||
super().__init__()
|
||||
super(ResidualCouplingLayer, self).__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
@@ -429,14 +502,20 @@ class ResidualCouplingLayer(nn.Module):
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
p_dropout=p_dropout,
|
||||
p_dropout=float(p_dropout),
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||
self.post.weight.data.zero_()
|
||||
self.post.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
reverse: bool = False,
|
||||
):
|
||||
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||
h = self.pre(x0) * x_mask
|
||||
h = self.enc(h, x_mask, g=g)
|
||||
@@ -455,11 +534,20 @@ class ResidualCouplingLayer(nn.Module):
|
||||
else:
|
||||
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
return x
|
||||
return x, torch.zeros([1])
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.enc.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.enc._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc)
|
||||
return self
|
||||
|
||||
|
||||
class ConvFlow(nn.Module):
|
||||
def __init__(
|
||||
@@ -471,7 +559,7 @@ class ConvFlow(nn.Module):
|
||||
num_bins=10,
|
||||
tail_bound=5.0,
|
||||
):
|
||||
super().__init__()
|
||||
super(ConvFlow, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
@@ -488,7 +576,13 @@ class ConvFlow(nn.Module):
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
reverse=False,
|
||||
):
|
||||
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||
h = self.pre(x0)
|
||||
h = self.convs(h, x_mask, g=g)
|
||||
|
||||
Reference in New Issue
Block a user