Skip to content

Commit ced7c96

Browse files
authored
fix upcast in slice attention (huggingface#1591)
* fix upcast in slice attention * fix dtype * add test * fix test
1 parent 8e74efa commit ced7c96

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

src/diffusers/models/attention.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -649,9 +649,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
649649
key_slice = key_slice.float()
650650

651651
attn_slice = torch.baddbmm(
652-
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
653-
query[start_idx:end_idx],
654-
key[start_idx:end_idx].transpose(-1, -2),
652+
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
653+
query_slice,
654+
key_slice.transpose(-1, -2),
655655
beta=0,
656656
alpha=self.scale,
657657
)

tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py

+19
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,25 @@ def test_stable_diffusion_v_pred_default(self):
265265
expected_slice = np.array([0.0567, 0.057, 0.0416, 0.0463, 0.0433, 0.06, 0.0517, 0.0526, 0.0866])
266266
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
267267

268+
def test_stable_diffusion_v_pred_upcast_attention(self):
269+
sd_pipe = StableDiffusionPipeline.from_pretrained(
270+
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
271+
)
272+
sd_pipe = sd_pipe.to(torch_device)
273+
sd_pipe.enable_attention_slicing()
274+
sd_pipe.set_progress_bar_config(disable=None)
275+
276+
prompt = "A painting of a squirrel eating a burger"
277+
generator = torch.Generator(device=torch_device).manual_seed(0)
278+
output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np")
279+
280+
image = output.images
281+
image_slice = image[0, 253:256, 253:256, -1]
282+
283+
assert image.shape == (1, 768, 768, 3)
284+
expected_slice = np.array([0.0461, 0.0483, 0.0566, 0.0512, 0.0446, 0.0751, 0.0664, 0.0551, 0.0488])
285+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
286+
268287
def test_stable_diffusion_v_pred_euler(self):
269288
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
270289
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler)

0 commit comments

Comments
 (0)