Feature: Intel ARC GPU support with IPEX (#1204)

* Initial Intel ARC support with IPEX

* Fix infer

* Fix train model

* Cleanup

* Cleanup

* Update README

* Make pylint happy

* Move dataloader fix to hijacks

* Fix torch.linalg.solve

* Fix SDP

* Add has_xpu to config.py

* Revert return_xpu fix
This commit is contained in:
Disty0
2023-09-09 07:00:29 +03:00
committed by GitHub
parent c761bda09a
commit 0c94f60093
13 changed files with 817 additions and 20 deletions

View File

@@ -17,6 +17,18 @@ n_gpus = len(hps.gpus.split("-"))
from random import randint, shuffle
import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
from infer.modules.ipex.gradscaler import gradscaler_init
from torch.xpu.amp import autocast
GradScaler = gradscaler_init()
ipex_init()
else:
from torch.cuda.amp import GradScaler, autocast
except Exception:
from torch.cuda.amp import GradScaler, autocast
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
@@ -25,7 +37,6 @@ from time import time as ttime
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
@@ -185,7 +196,9 @@ def run(rank, n_gpus, hps):
)
# net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
# net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
if torch.cuda.is_available():
if hasattr(torch, "xpu") and torch.xpu.is_available():
pass
elif torch.cuda.is_available():
net_g = DDP(net_g, device_ids=[rank])
net_d = DDP(net_d, device_ids=[rank])
else:
@@ -212,19 +225,33 @@ def run(rank, n_gpus, hps):
if hps.pretrainG != "":
if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainG))
logger.info(
net_g.module.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
) ##测试不加载优化器
if hasattr(net_g, "module"):
logger.info(
net_g.module.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
) ##测试不加载优化器
else:
logger.info(
net_g.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
) ##测试不加载优化器
if hps.pretrainD != "":
if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainD))
logger.info(
net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
if hasattr(net_d, "module"):
logger.info(
net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)
)
else:
logger.info(
net_d.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)
)
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2