chore(format): run black on main

This commit is contained in:
github-actions[bot]
2024-01-26 08:10:04 +00:00
parent 8790ab69e0
commit 005f097fec
15 changed files with 264 additions and 146 deletions

View File

@@ -62,12 +62,12 @@ def torch_bmm(input, mat2, *, out=None):
): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[
start_idx:end_idx, start_idx_2:end_idx_2
] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out,
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = (
original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out,
)
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
@@ -138,61 +138,67 @@ def scaled_dot_product_attention(
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if no_shape_one:
hidden_states[
start_idx:end_idx, start_idx_2:end_idx_2
] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[
start_idx:end_idx, start_idx_2:end_idx_2
]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = (
original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=(
attn_mask[start_idx:end_idx, start_idx_2:end_idx_2]
if attn_mask is not None
else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
)
)
else:
hidden_states[
:, start_idx:end_idx, start_idx_2:end_idx_2
] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[
:, start_idx:end_idx, start_idx_2:end_idx_2
]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = (
original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=(
attn_mask[
:, start_idx:end_idx, start_idx_2:end_idx_2
]
if attn_mask is not None
else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
)
)
else:
if no_shape_one:
hidden_states[
start_idx:end_idx
] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
hidden_states[start_idx:end_idx] = (
original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=(
attn_mask[start_idx:end_idx]
if attn_mask is not None
else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
)
)
else:
hidden_states[
:, start_idx:end_idx
] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx]
if attn_mask is not None
else attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
hidden_states[:, start_idx:end_idx] = (
original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=(
attn_mask[:, start_idx:end_idx]
if attn_mask is not None
else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
)
)
else:
return original_scaled_dot_product_attention(

View File

@@ -104,11 +104,11 @@ def return_xpu(device):
return (
f"xpu:{device[-1]}"
if isinstance(device, str) and ":" in device
else f"xpu:{device}"
if isinstance(device, int)
else torch.device("xpu")
if isinstance(device, torch.device)
else "xpu"
else (
f"xpu:{device}"
if isinstance(device, int)
else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
)
)
@@ -271,12 +271,16 @@ def ipex_hijacks():
"torch.batch_norm",
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
input,
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device),
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device),
(
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device)
),
(
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device)
),
*args,
**kwargs,
),
@@ -286,12 +290,16 @@ def ipex_hijacks():
"torch.instance_norm",
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
input,
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device),
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device),
(
weight
if weight is not None
else torch.ones(input.size()[1], device=input.device)
),
(
bias
if bias is not None
else torch.zeros(input.size()[1], device=input.device)
),
*args,
**kwargs,
),