Format code (#1193)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2023-09-14 09:34:30 +09:00
committed by GitHub
parent 72a18e66b6
commit a6456f6d46
15 changed files with 562 additions and 237 deletions

View File

@@ -15,6 +15,7 @@ from infer.lib.infer_pack.commons import get_padding, init_weights
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class TextEncoder256(nn.Module):
def __init__(
self,
@@ -1158,7 +1159,9 @@ class DiscriminatorP(torch.nn.Module):
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
if has_xpu and x.dtype == torch.bfloat16:
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16)
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
dtype=torch.bfloat16
)
else:
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad