Feature: Intel ARC GPU support with IPEX (#1204)

* Initial Intel ARC support with IPEX

* Fix infer

* Fix train model

* Cleanup

* Cleanup

* Update README

* Make pylint happy

* Move dataloader fix to hijacks

* Fix torch.linalg.solve

* Fix SDP

* Add has_xpu to config.py

* Revert return_xpu fix
This commit is contained in:
Disty0
2023-09-09 07:00:29 +03:00
committed by GitHub
parent c761bda09a
commit 0c94f60093
13 changed files with 817 additions and 20 deletions

View File

@@ -13,6 +13,7 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from infer.lib.infer_pack import attentions, commons, modules
from infer.lib.infer_pack.commons import get_padding, init_weights
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class TextEncoder256(nn.Module):
def __init__(
@@ -1156,7 +1157,10 @@ class DiscriminatorP(torch.nn.Module):
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
if has_xpu and x.dtype == torch.bfloat16:
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16)
else:
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)