Format code (#366)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2023-05-28 16:06:11 +00:00
committed by GitHub
parent e569477457
commit e435b3bb8a
6 changed files with 262 additions and 170 deletions

View File

@@ -3,9 +3,11 @@ from torch import nn
import torch.nn.functional as F
from uvr5_pack.lib_v5 import layers_new as layers
class BaseNet(nn.Module):
def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))):
class BaseNet(nn.Module):
def __init__(
self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))
):
super(BaseNet, self).__init__()
self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1)
self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1)
@@ -38,8 +40,8 @@ class BaseNet(nn.Module):
return h
class CascadedNet(nn.Module):
class CascadedNet(nn.Module):
def __init__(self, n_fft, nout=32, nout_lstm=128):
super(CascadedNet, self).__init__()
@@ -50,24 +52,30 @@ class CascadedNet(nn.Module):
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)
layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0),
)
self.stg1_high_band_net = BaseNet(
2, nout // 4, self.nin_lstm // 2, nout_lstm // 2
)
self.stg2_low_band_net = nn.Sequential(
BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2)
layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0),
)
self.stg2_high_band_net = BaseNet(
nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
)
self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)
self.stg3_full_band_net = BaseNet(
3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm
)
self.out = nn.Conv2d(nout, 2, 1, bias=False)
self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
def forward(self, x):
x = x[:, :, :self.max_bin]
x = x[:, :, : self.max_bin]
bandw = x.size()[2] // 2
l1_in = x[:, :, :bandw]
@@ -89,7 +97,7 @@ class CascadedNet(nn.Module):
mask = F.pad(
input=mask,
pad=(0, 0, 0, self.output_bin - mask.size()[2]),
mode='replicate'
mode="replicate",
)
if self.training:
@@ -98,7 +106,7 @@ class CascadedNet(nn.Module):
aux = F.pad(
input=aux,
pad=(0, 0, 0, self.output_bin - aux.size()[2]),
mode='replicate'
mode="replicate",
)
return mask, aux
else:
@@ -108,17 +116,17 @@ class CascadedNet(nn.Module):
mask = self.forward(x)
if self.offset > 0:
mask = mask[:, :, :, self.offset:-self.offset]
mask = mask[:, :, :, self.offset : -self.offset]
assert mask.size()[3] > 0
return mask
def predict(self, x,aggressiveness=None):
def predict(self, x, aggressiveness=None):
mask = self.forward(x)
pred_mag = x * mask
if self.offset > 0:
pred_mag = pred_mag[:, :, :, self.offset:-self.offset]
pred_mag = pred_mag[:, :, :, self.offset : -self.offset]
assert pred_mag.size()[3] > 0
return pred_mag