Format code (#1193)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Author: github-actions[bot]
Date: 2023-09-14 09:34:30 +09:00
Committed by: GitHub
Parent: 72a18e66b6
Commit: a6456f6d46
15 changed files with 562 additions and 237 deletions

View File

@@ -3,38 +3,49 @@ import numpy as np
import av
from io import BytesIO
def wav2(i, o, format):
inp = av.open(i, 'rb')
if format == "m4a": format = "mp4"
out = av.open(o, 'wb', format=format)
if format == "ogg": format = "libvorbis"
if format == "mp4": format = "aac"
inp = av.open(i, "rb")
if format == "m4a":
format = "mp4"
out = av.open(o, "wb", format=format)
if format == "ogg":
format = "libvorbis"
if format == "mp4":
format = "aac"
ostream = out.add_stream(format)
for frame in inp.decode(audio=0):
for p in ostream.encode(frame): out.mux(p)
for p in ostream.encode(frame):
out.mux(p)
for p in ostream.encode(None): out.mux(p)
for p in ostream.encode(None):
out.mux(p)
out.close()
inp.close()
def audio2(i, o, format, sr):
inp = av.open(i, 'rb')
out = av.open(o, 'wb', format=format)
if format == "ogg": format = "libvorbis"
if format == "f32le": format = "pcm_f32le"
inp = av.open(i, "rb")
out = av.open(o, "wb", format=format)
if format == "ogg":
format = "libvorbis"
if format == "f32le":
format = "pcm_f32le"
ostream = out.add_stream(format, channels=1)
ostream.sample_rate = sr
for frame in inp.decode(audio=0):
for p in ostream.encode(frame): out.mux(p)
for p in ostream.encode(frame):
out.mux(p)
out.close()
inp.close()
def load_audio(file, sr):
try:
file = (

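For context, a minimal usage sketch of the two helpers reformatted above (not part of this commit): wav2 re-encodes an already-decoded stream into another container, and audio2 re-encodes to a mono stream at a target sample rate. The file names are illustrative; the sketch assumes PyAV ("av") is installed and the input file exists.

from io import BytesIO

# wav2() and audio2() are the helpers defined in the file above.
# Convert an in-memory WAV to Ogg/Vorbis.
with open("input.wav", "rb") as f:
    wav_bytes = BytesIO(f.read())
with open("output.ogg", "wb") as out_file:
    wav2(wav_bytes, out_file, "ogg")

# Re-encode to mono 16 kHz raw float32 PCM.
with open("input.wav", "rb") as f, open("output.f32le", "wb") as out_file:
    audio2(f, out_file, "f32le", 16000)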
View File

@@ -15,6 +15,7 @@ from infer.lib.infer_pack.commons import get_padding, init_weights
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class TextEncoder256(nn.Module):
def __init__(
self,
@@ -1158,7 +1159,9 @@ class DiscriminatorP(torch.nn.Module):
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
if has_xpu and x.dtype == torch.bfloat16:
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16)
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
dtype=torch.bfloat16
)
else:
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
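The hunk above is the XPU workaround in DiscriminatorP: when the tensor is bfloat16 on an Intel GPU, it is temporarily cast to float16 for the reflect padding and cast back, presumably because bfloat16 reflection padding is not supported there. A standalone sketch of the same idea (shapes and the helper name are illustrative, not from the commit):

import torch
import torch.nn.functional as F

def reflect_pad_bf16_safe(x, pad, use_fp16_fallback):
    """Reflect-pad x; if the backend lacks bfloat16 reflect padding,
    round-trip through float16 as the diff above does."""
    if use_fp16_fallback and x.dtype == torch.bfloat16:
        return F.pad(x.to(dtype=torch.float16), pad, "reflect").to(dtype=torch.bfloat16)
    return F.pad(x, pad, "reflect")

# Example: pad the last dimension so its length becomes a multiple of `period`.
x = torch.randn(1, 1, 19)          # (batch, channels, time)
period = 5
n_pad = period - (x.shape[-1] % period)
y = reflect_pad_bf16_safe(x, (0, n_pad), use_fp16_fallback=False)
assert y.shape[-1] % period == 0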

View File

@@ -2,11 +2,14 @@ import pdb, os
import numpy as np
import torch
try:
#Fix "Torch not compiled with CUDA enabled"
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# Fix "Torch not compiled with CUDA enabled"
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
ipex_init()
except Exception:
pass

View File

@@ -2,15 +2,16 @@ import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
def ipex_init(): # pylint: disable=too-many-statements
try:
#Replace cuda with xpu:
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
@@ -91,11 +92,11 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
#Memory:
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
if "linux" in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
@@ -111,9 +112,11 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
torch.cuda.reset_accumulated_memory_stats = (
torch.xpu.reset_accumulated_memory_stats
)
#RNG:
# RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
@@ -124,35 +127,44 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
#AMP:
# AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
from .gradscaler import (
gradscaler_init,
) # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory]
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [
(
torch.xpu.get_device_properties(device).total_memory
- torch.xpu.memory_allocated(device)
),
torch.xpu.get_device_properties(device).total_memory,
]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7"
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
torch.cuda.get_device_capability = lambda *args, **kwargs: [11, 7]
torch.cuda.get_device_properties.major = 11
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
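ipex_init() above aliases the torch.cuda namespace onto torch.xpu so that CUDA-only code paths keep working on Intel GPUs. A minimal, self-contained sketch of that pattern (guarded so it is a no-op when no XPU backend is present; only the wiring is illustrative, the attribute names are the real torch ones):

import torch

def alias_cuda_to_xpu():
    """Point a few torch.cuda entry points at their torch.xpu equivalents,
    mirroring what ipex_init() does above. Returns True if anything was patched."""
    xpu = getattr(torch, "xpu", None)
    if xpu is None or not xpu.is_available():
        return False  # nothing to do without an XPU backend
    torch.cuda.is_available = xpu.is_available
    torch.cuda.current_device = xpu.current_device
    torch.cuda.device_count = xpu.device_count
    torch.cuda.empty_cache = xpu.empty_cache
    torch._utils._get_available_device_type = lambda: "xpu"
    return True

print("cuda->xpu aliases installed:", alias_cuda_to_xpu())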

View File

@@ -1,22 +1,32 @@
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = (
input.shape[0],
input.shape[1],
mat2.shape[2],
)
block_multiply = 2.4 if input.dtype == torch.float32 else 1.2
block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB
block_size = (
(batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply
) # MB
split_slice_size = batch_size_attention
if block_size >= 4000:
do_split = True
#Find something divisible with the input_tokens
while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000:
# Find something divisible with the input_tokens
while (
(split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply
) > 4000:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
@@ -24,12 +34,16 @@ def torch_bmm(input, mat2, *, out=None):
else:
do_split = False
split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB
split_block_size = (
(split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply
) # MB
split_2_slice_size = input_tokens
if split_block_size >= 4000:
do_split_2 = True
#Find something divisible with the input_tokens
while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000:
# Find something divisible with the input_tokens
while (
(split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply
) > 4000:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
@@ -38,40 +52,61 @@ def torch_bmm(input, mat2, *, out=None):
do_split_2 = False
if do_split:
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
hidden_states = torch.zeros(
input.shape[0],
input.shape[1],
mat2.shape[2],
device=input.device,
dtype=input.dtype,
)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
for i2 in range(
input_tokens // split_2_slice_size
): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
hidden_states[
start_idx:end_idx, start_idx_2:end_idx_2
] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
out=out,
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
input[start_idx:end_idx], mat2[start_idx:end_idx], out=out
)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
):
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
block_size = (
(shape_one * batch_size_attention * query_tokens * shape_four)
/ 1024
* block_multiply
) # MB
split_slice_size = batch_size_attention
if block_size >= 4000:
do_split = True
#Find something divisible with the shape_one
while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
# Find something divisible with the shape_one
while (
(shape_one * split_slice_size * query_tokens * shape_four)
/ 1024
* block_multiply
) > 4000:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
@@ -79,12 +114,20 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
else:
do_split = False
split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
split_block_size = (
(shape_one * split_slice_size * query_tokens * shape_four)
/ 1024
* block_multiply
) # MB
split_2_slice_size = query_tokens
if split_block_size >= 4000:
do_split_2 = True
#Find something divisible with the batch_size_attention
while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
# Find something divisible with the batch_size_attention
while (
(shape_one * split_slice_size * split_2_slice_size * shape_four)
/ 1024
* block_multiply
) > 4000:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
@@ -98,31 +141,49 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
for i2 in range(
query_tokens // split_2_slice_size
): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
hidden_states[
:, start_idx:end_idx, start_idx_2:end_idx_2
] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)
else:
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
hidden_states[
:, start_idx:end_idx
] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
attn_mask=attn_mask[:, start_idx:end_idx]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
query,
key,
value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)
return hidden_states
def attention_init():
#ARC GPUs can't allocate more than 4GB to a single block:
# ARC GPUs can't allocate more than 4GB to a single block:
torch.bmm = torch_bmm
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
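The slicing logic reformatted above keeps any single bmm/attention block under roughly 4 GB, since ARC GPUs cannot allocate more than that in one block. A standalone sketch of the size estimate and the halving loop it uses (the 2.4/1.2 multipliers and the 4000 MB threshold are taken from the diff; the tensor shapes are illustrative):

import torch

def pick_split_slice_size(batch, tokens, last_dim, dtype):
    """Halve the batch slice until the estimated block size drops below ~4000 MB,
    matching the loop in torch_bmm() above."""
    block_multiply = 2.4 if dtype == torch.float32 else 1.2
    split = batch
    while (split * tokens * last_dim) / 1024 * block_multiply > 4000 and split > 1:
        split = max(split // 2, 1)
    return split

# Example: a hypothetical bmm that would exceed the 4 GB block limit.
print(pick_split_slice_size(64, 4096, 4096, torch.float32))  # much smaller than 64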

View File

@@ -1,15 +1,20 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
_refresh_per_optimizer_state = (
ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
def _unscale_grads_(
self, optimizer, inv_scale, found_inf, allow_fp16
): # pylint: disable=unused-argument
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
@@ -43,9 +48,9 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint
# -: is there a way to split by device and dtype without appending in the inner loop?
to_unscale = to_unscale.to("cpu")
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
to_unscale
)
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
@@ -57,6 +62,7 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
@@ -87,7 +93,7 @@ def unscale_(self, optimizer):
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
@@ -96,16 +102,17 @@ def unscale_(self, optimizer):
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
inv_scale = (
self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
)
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def update(self, new_scale=None):
"""
Updates the scale factor.
@@ -171,6 +178,7 @@ def update(self, new_scale=None):
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def gradscaler_init():
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
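The patched GradScaler above is consumed by the training code in the usual scale/step/update pattern. A minimal sketch of that pattern using the plain PyTorch API (enabled=False keeps it runnable without a GPU; on XPU, gradscaler_init() above simply swaps in the patched scaler):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled so this runs on CPU

x, y = torch.randn(8, 4), torch.randn(8, 1)
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()   # no-op scaling when disabled
scaler.unscale_(optimizer)      # the method this file reimplements for XPU
scaler.step(optimizer)
scaler.update()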

View File

@@ -1,45 +1,59 @@
import contextlib
import importlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class CondFunc: # pylint: disable=missing-class-docstring
class CondFunc: # pylint: disable=missing-class-docstring
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
func_path = orig_func.split(".")
for i in range(len(func_path) - 1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
resolved_obj = importlib.import_module(".".join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
setattr(
resolved_obj,
func_path[-1],
lambda *args, **kwargs: self(*args, **kwargs),
)
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
_utils = torch.utils.data._utils
def _shutdown_workers(self):
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
if (
torch.utils.data._utils is None
or torch.utils.data._utils.python_exit_status is True
or torch.utils.data._utils.python_exit_status is None
):
return
if hasattr(self, "_shutdown") and not self._shutdown:
self._shutdown = True
try:
if hasattr(self, '_pin_memory_thread'):
if hasattr(self, "_pin_memory_thread"):
self._pin_memory_thread_done_event.set()
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
@@ -49,145 +63,292 @@ def _shutdown_workers(self):
for worker_id in range(len(self._workers)):
if self._persistent_workers or self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers: # pylint: disable=invalid-name
for w in self._workers: # pylint: disable=invalid-name
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
for q in self._index_queues: # pylint: disable=invalid-name
for q in self._index_queues: # pylint: disable=invalid-name
q.cancel_join_thread()
q.close()
finally:
if self._worker_pids_set:
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
for w in self._workers: # pylint: disable=invalid-name
for w in self._workers: # pylint: disable=invalid-name
if w.is_alive():
w.terminate()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
class DummyDataParallel(
torch.nn.Module
): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(
cls, module, device_ids=None, output_device=None, dim=0
): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
return bool(
(isinstance(device, torch.device) and device.type == "cuda")
or (isinstance(device, str) and "cuda" in device)
or isinstance(device, int)
)
def return_xpu(device):
return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
return (
f"xpu:{device[-1]}"
if isinstance(device, str) and ":" in device
else f"xpu:{device}"
if isinstance(device, int)
else torch.device("xpu")
if isinstance(device, torch.device)
else "xpu"
)
def ipex_no_cuda(orig_func, *args, **kwargs):
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
return original_autocast("xpu", *args[1:], **kwargs)
else:
return original_autocast(*args, **kwargs)
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
if len(tensor) == 3 and (
tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype
):
return original_torch_cat(
[tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)],
*args,
**kwargs,
)
else:
return original_torch_cat(tensor, *args, **kwargs)
original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
def interpolate(
tensor,
size=None,
scale_factor=None,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
antialias=False,
): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
return original_interpolate(
tensor.to("cpu", dtype=torch.float32),
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor,
antialias=antialias,
).to(return_device, dtype=return_dtype)
else:
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
return original_interpolate(
tensor,
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor,
antialias=antialias,
)
original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
return_device = A.device
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(
return_device
)
else:
return original_linalg_solve(A, B, *args, **kwargs)
def ipex_hijacks():
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.ones',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc(
"torch.Tensor.to",
lambda orig_func, self, device=None, *args, **kwargs: orig_func(
self, return_xpu(device), *args, **kwargs
),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device),
)
CondFunc(
"torch.Tensor.cuda",
lambda orig_func, self, device=None, *args, **kwargs: orig_func(
self, return_xpu(device), *args, **kwargs
),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device),
)
CondFunc(
"torch.empty",
lambda orig_func, *args, device=None, **kwargs: orig_func(
*args, device=return_xpu(device), **kwargs
),
lambda orig_func, *args, device=None, **kwargs: check_device(device),
)
CondFunc(
"torch.load",
lambda orig_func, *args, map_location=None, **kwargs: orig_func(
*args, return_xpu(map_location), **kwargs
),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None
or check_device(map_location),
)
CondFunc(
"torch.randn",
lambda orig_func, *args, device=None, **kwargs: orig_func(
*args, device=return_xpu(device), **kwargs
),
lambda orig_func, *args, device=None, **kwargs: check_device(device),
)
CondFunc(
"torch.ones",
lambda orig_func, *args, device=None, **kwargs: orig_func(
*args, device=return_xpu(device), **kwargs
),
lambda orig_func, *args, device=None, **kwargs: check_device(device),
)
CondFunc(
"torch.zeros",
lambda orig_func, *args, device=None, **kwargs: orig_func(
*args, device=return_xpu(device), **kwargs
),
lambda orig_func, *args, device=None, **kwargs: check_device(device),
)
CondFunc(
"torch.tensor",
lambda orig_func, *args, device=None, **kwargs: orig_func(
*args, device=return_xpu(device), **kwargs
),
lambda orig_func, *args, device=None, **kwargs: check_device(device),
)
CondFunc(
"torch.linspace",
lambda orig_func, *args, device=None, **kwargs: orig_func(
*args, device=return_xpu(device), **kwargs
),
lambda orig_func, *args, device=None, **kwargs: check_device(device),
)
CondFunc('torch.Generator',
CondFunc(
"torch.Generator",
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
lambda orig_func, device=None: device is not None
and device != torch.device("cpu")
and device != "cpu",
)
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc(
"torch.batch_norm",
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
input,
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device),
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device),
*args,
**kwargs,
),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"),
)
CondFunc(
"torch.instance_norm",
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
input,
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device),
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device),
*args,
**kwargs,
),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"),
)
#Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
# Functions with dtype errors:
CondFunc(
"torch.nn.modules.GroupNorm.forward",
lambda orig_func, self, input: orig_func(
self, input.to(self.weight.data.dtype)
),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
)
CondFunc(
"torch.nn.modules.linear.Linear.forward",
lambda orig_func, self, input: orig_func(
self, input.to(self.weight.data.dtype)
),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
)
CondFunc(
"torch.nn.modules.conv.Conv2d.forward",
lambda orig_func, self, input: orig_func(
self, input.to(self.weight.data.dtype)
),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
)
CondFunc(
"torch.nn.functional.layer_norm",
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(
input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs
),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight
is not None
and input.dtype != weight.data.dtype,
)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
# Diffusers Float64 (ARC GPUs doesn't support double or Float64):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
CondFunc(
"torch.from_numpy",
lambda orig_func, ndarray: orig_func(ndarray.astype("float32")),
lambda orig_func, ndarray: ndarray.dtype == float,
)
#Broken functions when torch.cuda.is_available is True:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
# Broken functions when torch.cuda.is_available is True:
CondFunc(
"torch.utils.data.dataloader._BaseDataLoaderIter.__init__",
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
lambda orig_func, *args, **kwargs: True,
)
#Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
# Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = (
_shutdown_workers
)
torch.nn.DataParallel = DummyDataParallel
torch.autocast = ipex_autocast
torch.cat = torch_cat
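CondFunc above wraps a target function (looked up by dotted path or passed directly) and routes calls through a substitute only when a condition holds; ipex_hijacks() uses it to redirect "cuda" devices to "xpu". A tiny self-contained sketch of the same pattern, independent of torch (all names here are made up for illustration):

class CondFuncDemo:
    """Minimal reimplementation of the CondFunc idea above: call `sub` instead of
    `orig` whenever `cond(orig, *args, **kwargs)` is true."""

    def __init__(self, orig, sub, cond):
        self._orig, self._sub, self._cond = orig, sub, cond

    def __call__(self, *args, **kwargs):
        if self._cond is None or self._cond(self._orig, *args, **kwargs):
            return self._sub(self._orig, *args, **kwargs)
        return self._orig(*args, **kwargs)

def make_tensor_like(device=None):
    return f"allocated on {device}"

# Redirect any "cuda" request to "xpu", the kind of hijack done above.
make_tensor_like = CondFuncDemo(
    make_tensor_like,
    lambda orig, device=None: orig(device="xpu"),
    lambda orig, device=None: isinstance(device, str) and "cuda" in device,
)
print(make_tensor_like(device="cuda:0"))  # -> allocated on xpu
print(make_tensor_like(device="cpu"))     # -> allocated on cpu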

View File

@@ -17,12 +17,15 @@ n_gpus = len(hps.gpus.split("-"))
from random import randint, shuffle
import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
from infer.modules.ipex.gradscaler import gradscaler_init
from torch.xpu.amp import autocast
GradScaler = gradscaler_init()
ipex_init()
else:

View File

@@ -288,14 +288,13 @@ class VC:
tgt_sr,
)
else:
path = "%s/%s.%s" % (opt_root, os.path.basename(path), format1)
path = "%s/%s.%s" % (
opt_root,
os.path.basename(path),
format1,
)
with BytesIO() as wavf:
sf.write(
wavf,
audio_opt,
tgt_sr,
format="wav"
)
sf.write(wavf, audio_opt, tgt_sr, format="wav")
wavf.seek(0, 0)
with open(path, "wb") as outf:
wav2(wavf, outf, format1)
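The hunk above keeps the export flow unchanged: the converted audio is first written to an in-memory WAV with soundfile, then wav2() transcodes it into the requested container. A standalone sketch of that flow (assumes the soundfile and av packages and the wav2 helper from the first file in this commit; the sample data and output name are illustrative):

from io import BytesIO

import numpy as np
import soundfile as sf

tgt_sr = 16000
# One second of a 440 Hz sine wave as stand-in output audio.
audio_opt = (0.1 * np.sin(2 * np.pi * 440 * np.arange(tgt_sr) / tgt_sr)).astype(np.float32)

with BytesIO() as wavf:
    sf.write(wavf, audio_opt, tgt_sr, format="wav")  # encode to WAV in memory
    wavf.seek(0, 0)
    with open("out.m4a", "wb") as outf:
        wav2(wavf, outf, "m4a")  # remux/re-encode with the helper from infer.lib's audio module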