Format code (#366)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2023-05-28 16:06:11 +00:00
committed by GitHub
parent e569477457
commit e435b3bb8a
6 changed files with 262 additions and 170 deletions

View File

@@ -4,27 +4,29 @@ import torch.nn.functional as F
from uvr5_pack.lib_v5 import spec_utils
class Conv2DBNActiv(nn.Module):
class Conv2DBNActiv(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
super(Conv2DBNActiv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(
nin, nout,
nin,
nout,
kernel_size=ksize,
stride=stride,
padding=pad,
dilation=dilation,
bias=False),
bias=False,
),
nn.BatchNorm2d(nout),
activ()
activ(),
)
def __call__(self, x):
return self.conv(x)
class Encoder(nn.Module):
class Encoder(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
super(Encoder, self).__init__()
self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
@@ -38,15 +40,16 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
def __init__(
self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
):
super(Decoder, self).__init__()
self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
# self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
self.dropout = nn.Dropout2d(0.1) if dropout else None
def __call__(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
if skip is not None:
skip = spec_utils.crop_center(skip, x)
@@ -62,12 +65,11 @@ class Decoder(nn.Module):
class ASPPModule(nn.Module):
def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
super(ASPPModule, self).__init__()
self.conv1 = nn.Sequential(
nn.AdaptiveAvgPool2d((1, None)),
Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
)
self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
self.conv3 = Conv2DBNActiv(
@@ -84,7 +86,9 @@ class ASPPModule(nn.Module):
def forward(self, x):
_, _, h, w = x.size()
feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
feat1 = F.interpolate(
self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
)
feat2 = self.conv2(x)
feat3 = self.conv3(x)
feat4 = self.conv4(x)
@@ -99,19 +103,14 @@ class ASPPModule(nn.Module):
class LSTMModule(nn.Module):
def __init__(self, nin_conv, nin_lstm, nout_lstm):
super(LSTMModule, self).__init__()
self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
self.lstm = nn.LSTM(
input_size=nin_lstm,
hidden_size=nout_lstm // 2,
bidirectional=True
input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True
)
self.dense = nn.Sequential(
nn.Linear(nout_lstm, nin_lstm),
nn.BatchNorm1d(nin_lstm),
nn.ReLU()
nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
)
def forward(self, x):