Commit ae3f625

Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev

2 parents f1f30ab + b86af67

3 files changed: +33 -20 lines
finetune/tag_images_by_wd14_tagger.py

Lines changed: 17 additions & 6 deletions

@@ -142,12 +142,23 @@ def main(args):
 
         del model
 
-        ort_sess = ort.InferenceSession(
-            onnx_path,
-            providers=(
-                ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
-            ),
-        )
+        if "OpenVINOExecutionProvider" in ort.get_available_providers():
+            # requires provider options for gpu support
+            # fp16 causes nonsense outputs
+            ort_sess = ort.InferenceSession(
+                onnx_path,
+                providers=(["OpenVINOExecutionProvider"]),
+                provider_options=[{'device_type' : "GPU_FP32"}],
+            )
+        else:
+            ort_sess = ort.InferenceSession(
+                onnx_path,
+                providers=(
+                    ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
+                    ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
+                    ["CPUExecutionProvider"]
+                ),
+            )
     else:
         from tensorflow.keras.models import load_model
 
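Note on the change above: onnxruntime tries the execution providers passed to InferenceSession in order, so the new branch prefers OpenVINO (which needs explicit provider_options for GPU, and FP32 because FP16 reportedly produces nonsense outputs) and otherwise falls back through CUDA and ROCm to CPU. The snippet below is a minimal standalone sketch of that selection logic, not code from this commit; the make_session helper name and the model path are placeholders.

import onnxruntime as ort

def make_session(onnx_path):
    # hypothetical helper mirroring the provider priority used by the tagger script
    available = ort.get_available_providers()
    if "OpenVINOExecutionProvider" in available:
        # OpenVINO needs explicit device options for GPU; FP32 avoids the bad FP16 outputs
        return ort.InferenceSession(
            onnx_path,
            providers=["OpenVINOExecutionProvider"],
            provider_options=[{"device_type": "GPU_FP32"}],
        )
    for provider in ("CUDAExecutionProvider", "ROCMExecutionProvider"):
        if provider in available:
            return ort.InferenceSession(onnx_path, providers=[provider])
    return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# usage (path is a placeholder):
# ort_sess = make_session("wd14_tagger_model/model.onnx")
# probs = ort_sess.run(None, {ort_sess.get_inputs()[0].name: image_batch})[0]
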
library/ipex/attention.py

Lines changed: 8 additions & 8 deletions

@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
                 mat2[start_idx:end_idx],
                 out=out
             )
+        torch.xpu.synchronize(input.device)
     else:
         return original_torch_bmm(input, mat2, out=out)
-    torch.xpu.synchronize(input.device)
     return hidden_states
 
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
     if query.device.type != "xpu":
-        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
 
     # Slice SDPA

@@ -153,25 +153,25 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
                                 key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
-                                dropout_p=dropout_p, is_causal=is_causal
+                                dropout_p=dropout_p, is_causal=is_causal, **kwargs
                             )
                     else:
                         hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
                             query[start_idx:end_idx, start_idx_2:end_idx_2],
                             key[start_idx:end_idx, start_idx_2:end_idx_2],
                             value[start_idx:end_idx, start_idx_2:end_idx_2],
                             attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
-                            dropout_p=dropout_p, is_causal=is_causal
+                            dropout_p=dropout_p, is_causal=is_causal, **kwargs
                         )
             else:
                 hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
                     query[start_idx:end_idx],
                     key[start_idx:end_idx],
                     value[start_idx:end_idx],
                     attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
-                    dropout_p=dropout_p, is_causal=is_causal
+                    dropout_p=dropout_p, is_causal=is_causal, **kwargs
                 )
+        torch.xpu.synchronize(query.device)
     else:
-        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
-    torch.xpu.synchronize(query.device)
+        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     return hidden_states

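Note on the change above: both hijacked ops now forward unknown keyword arguments via **kwargs, so callers that pass newer scaled_dot_product_attention parameters (for example scale=, added in recent PyTorch releases) still reach the original implementation instead of raising TypeError, and torch.xpu.synchronize is now issued inside the sliced branch. Below is a minimal sketch of the **kwargs pass-through pattern, independent of IPEX and not code from this commit.

import torch
import torch.nn.functional as F

# keep a handle to the original before monkeypatching
original_sdpa = F.scaled_dot_product_attention

def sdpa_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
    # **kwargs forwards keywords this wrapper does not know about
    # (e.g. scale=), so the hijack does not break newer callers
    return original_sdpa(query, key, value, attn_mask=attn_mask,
                         dropout_p=dropout_p, is_causal=is_causal, **kwargs)

F.scaled_dot_product_attention = sdpa_wrapper

q = k = v = torch.randn(1, 4, 8, 16)            # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v)   # routed through the wrapper
# extra keyword still works on PyTorch versions that accept scale=
out_scaled = F.scaled_dot_product_attention(q, k, v, scale=0.125)
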
library/ipex/hijacks.py

Lines changed: 8 additions & 6 deletions

@@ -12,7 +12,7 @@
 class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
     def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
         if isinstance(device_ids, list) and len(device_ids) > 1:
-            logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
+            print("IPEX backend doesn't support DataParallel on multiple XPU devices")
         return module.to("xpu")
 
 def return_null_context(*args, **kwargs): # pylint: disable=unused-argument

@@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
 original_interpolate = torch.nn.functional.interpolate
 @wraps(torch.nn.functional.interpolate)
 def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
-    if antialias or align_corners is not None:
+    if antialias or align_corners is not None or mode == 'bicubic':
         return_device = tensor.device
         return_dtype = tensor.dtype
         return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,

@@ -216,7 +216,9 @@ def torch_empty(*args, device=None, **kwargs):
 
 original_torch_randn = torch.randn
 @wraps(torch.randn)
-def torch_randn(*args, device=None, **kwargs):
+def torch_randn(*args, device=None, dtype=None, **kwargs):
+    if dtype == bytes:
+        dtype = None
     if check_device(device):
         return original_torch_randn(*args, device=return_xpu(device), **kwargs)
     else:

@@ -256,11 +258,11 @@ def torch_Generator(device=None):
 
 original_torch_load = torch.load
 @wraps(torch.load)
-def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
+def torch_load(f, map_location=None, *args, **kwargs):
     if check_device(map_location):
-        return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
+        return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
     else:
-        return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
+        return original_torch_load(f, map_location=map_location, *args, **kwargs)
 
 
 # Hijack Functions:

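Note on the changes above: torch_load now takes *args/**kwargs instead of naming pickle_module, weights_only and mmap explicitly, so the hijack keeps working as torch.load gains or changes keyword arguments. The sketch below shows that wrapping pattern in isolation; check_device and return_xpu are simplified stand-ins for the helpers defined elsewhere in hijacks.py, not their real implementations.

from functools import wraps
import torch

def check_device(device):
    # stand-in: treat any CUDA device spec as one to redirect
    return device is not None and "cuda" in str(device)

def return_xpu(device):
    # stand-in: map a CUDA device spec onto XPU
    return str(device).replace("cuda", "xpu")

original_torch_load = torch.load

@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
    # forward everything else untouched so the wrapper survives
    # signature changes in torch.load (weights_only, mmap, ...)
    if check_device(map_location):
        return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
    return original_torch_load(f, map_location=map_location, *args, **kwargs)

torch.load = torch_load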