
Commit 0cc3a7a

Make sure we also change the config when setting encoder_hid_dim_type=="text_proj" and allow xformers (huggingface#3615)
* fix if
* make style
* make style
* add tests for xformers
* make style
* update
1 parent 9d3ff07 commit 0cc3a7a

File tree: 11 files changed, +141 / -25 lines

examples/community/mixture.py

Lines changed: 2 additions & 5 deletions
@@ -215,11 +215,8 @@ def __call__(
             raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
         if isinstance(seed_tiles_mode, str):
             seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
-        if any(
-            mode not in (modes := [mode.value for mode in self.SeedTilesMode])
-            for row in seed_tiles_mode
-            for mode in row
-        ):
+        modes = [mode.value for mode in self.SeedTilesMode]
+        if any(mode not in modes for row in seed_tiles_mode for mode in row):
             raise ValueError(f"Seed tiles mode must be one of {modes}")
         if seed_reroll_regions is None:
             seed_reroll_regions = []
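Hoisting the list comprehension out of `any()` keeps the behaviour identical while evaluating the allowed-modes list once instead of on every iteration of the generator, and it keeps `modes` obviously in scope for the error message. A minimal standalone sketch of the new shape of the check (the enum below is a stand-in for the pipeline's `SeedTilesMode` and its values are assumed for illustration):

from enum import Enum


class SeedTilesMode(Enum):
    # Stand-in for the pipeline's SeedTilesMode enum; values assumed for illustration.
    FULL = "full"
    EXCLUSIVE = "exclusive"


seed_tiles_mode = [["full", "not-a-mode"]]  # one row containing an invalid entry

modes = [mode.value for mode in SeedTilesMode]  # computed once, reused below
try:
    if any(mode not in modes for row in seed_tiles_mode for mode in row):
        raise ValueError(f"Seed tiles mode must be one of {modes}")
except ValueError as err:
    print(err)  # Seed tiles mode must be one of ['full', 'exclusive']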

frog.png

108 KB

src/diffusers/models/attention_processor.py

Lines changed: 89 additions & 8 deletions
@@ -166,22 +166,28 @@ def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
         is_lora = hasattr(self, "processor") and isinstance(
-            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
+            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
         )
+        is_added_kv_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                AttnAddedKVProcessor,
+                AttnAddedKVProcessor2_0,
+                SlicedAttnAddedKVProcessor,
+                XFormersAttnAddedKVProcessor,
+                LoRAAttnAddedKVProcessor,
+            ),
+        )

         if use_memory_efficient_attention_xformers:
-            if self.added_kv_proj_dim is not None:
-                # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
-                # which uses this type of cross attention ONLY because the attention mask of format
-                # [0, ..., -10.000, ..., 0, ...,] is not supported
+            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                 raise NotImplementedError(
-                    "Memory efficient attention with `xformers` is currently not supported when"
-                    " `self.added_kv_proj_dim` is defined."
+                    f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
                 )
-            elif not is_xformers_available():
+            if not is_xformers_available():
                 raise ModuleNotFoundError(
                     (
                         "Refer to https://github.com/facebookresearch/xformers for more information on how to install"

@@ -233,6 +239,15 @@ def set_use_memory_efficient_attention_xformers(
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_added_kv_processor:
+                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+                # which uses this type of cross attention ONLY because the attention mask of format
+                # [0, ..., -10.000, ..., 0, ...,] is not supported
+                # throw warning
+                logger.info(
+                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+                )
+                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
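With the relaxed check, enabling memory-efficient attention at the pipeline level now routes added-KV attention blocks to the new processor instead of raising NotImplementedError. A sketch of the intended use, assuming a CUDA machine with `xformers` installed (the checkpoint name is only an example and requires accepting its license on the Hub):

import torch
from diffusers import DiffusionPipeline

# DeepFloyd IF uses the UnCLIP-style added-KV cross attention that used to be rejected.
pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.to("cuda")

# Internally calls set_use_memory_efficient_attention_xformers on every Attention module;
# added-KV blocks now receive XFormersAttnAddedKVProcessor instead of raising.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("an astronaut riding a horse").images[0]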
@@ -889,6 +904,71 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         return hidden_states


+class XFormersAttnAddedKVProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
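The new processor mirrors AttnAddedKVProcessor (the projected encoder states are concatenated with the self-attention keys and values) but dispatches the attention itself to xformers.ops.memory_efficient_attention. It can also be assigned per module; a sketch under the assumption that `unet` is an already-loaded UNet whose attention blocks use added KV projections (IF/UnCLIP-style):

from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    XFormersAttnAddedKVProcessor,
)

# `unet` is assumed to exist already, e.g. pipe.unet from an IF pipeline.
for module in unet.modules():
    current = getattr(module, "processor", None)
    if isinstance(current, AttnAddedKVProcessor):
        # Keep the added-KV projections but run the attention math through xFormers.
        module.set_processor(XFormersAttnAddedKVProcessor())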
@@ -1428,6 +1508,7 @@
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
+    XFormersAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,

src/diffusers/models/unet_2d_condition.py

Lines changed: 1 addition & 0 deletions
@@ -261,6 +261,7 @@ def __init__(

         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
             logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
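Previously the inferred value only lived in a local variable, so `unet.config.encoder_hid_dim_type` stayed `None` even though the projection layer was built; registering it keeps the config consistent with the module and with what gets written to config.json. A small sketch of the visible effect (the dummy model sizes are chosen arbitrarily for illustration):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    encoder_hid_dim=16,  # encoder_hid_dim_type deliberately left unset
)

# With this commit the inferred value is registered on the config as well.
print(unet.config.encoder_hid_dim_type)  # "text_proj"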

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 1 addition & 0 deletions
@@ -364,6 +364,7 @@ def __init__(

         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
             logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

         if encoder_hid_dim is None and encoder_hid_dim_type is not None:

tests/pipelines/deepfloyd_if/test_if.py

Lines changed: 8 additions & 2 deletions
@@ -28,6 +28,7 @@
     IFSuperResolutionPipeline,
 )
 from diffusers.models.attention_processor import AttnAddedKVProcessor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS

@@ -42,8 +43,6 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()

@@ -81,6 +80,13 @@ def test_inference_batch_single_identical(self):
             expected_max_diff=1e-2,
         )

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+

 @slow
 @require_torch_gpu
tests/pipelines/deepfloyd_if/test_if_img2img.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from diffusers import IFImg2ImgPipeline
2222
from diffusers.utils import floats_tensor
23+
from diffusers.utils.import_utils import is_xformers_available
2324
from diffusers.utils.testing_utils import skip_mps, torch_device
2425

2526
from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
3738
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
3839
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
3940

40-
test_xformers_attention = False
41-
4241
def get_dummy_components(self):
4342
return self._get_dummy_components()
4443

@@ -63,6 +62,13 @@ def get_dummy_inputs(self, device, seed=0):
6362
def test_save_load_optional_components(self):
6463
self._test_save_load_optional_components()
6564

65+
@unittest.skipIf(
66+
torch_device != "cuda" or not is_xformers_available(),
67+
reason="XFormers attention is only available with CUDA and `xformers` installed",
68+
)
69+
def test_xformers_attention_forwardGenerator_pass(self):
70+
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
71+
6672
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
6773
def test_save_load_float16(self):
6874
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder

tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from diffusers import IFImg2ImgSuperResolutionPipeline
2222
from diffusers.utils import floats_tensor
23+
from diffusers.utils.import_utils import is_xformers_available
2324
from diffusers.utils.testing_utils import skip_mps, torch_device
2425

2526
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
3435
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"})
3536
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
3637

37-
test_xformers_attention = False
38-
3938
def get_dummy_components(self):
4039
return self._get_superresolution_dummy_components()
4140

@@ -59,6 +58,13 @@ def get_dummy_inputs(self, device, seed=0):
5958

6059
return inputs
6160

61+
@unittest.skipIf(
62+
torch_device != "cuda" or not is_xformers_available(),
63+
reason="XFormers attention is only available with CUDA and `xformers` installed",
64+
)
65+
def test_xformers_attention_forwardGenerator_pass(self):
66+
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
67+
6268
def test_save_load_optional_components(self):
6369
self._test_save_load_optional_components()
6470

tests/pipelines/deepfloyd_if/test_if_inpainting.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from diffusers import IFInpaintingPipeline
2222
from diffusers.utils import floats_tensor
23+
from diffusers.utils.import_utils import is_xformers_available
2324
from diffusers.utils.testing_utils import skip_mps, torch_device
2425

2526
from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
3738
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
3839
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
3940

40-
test_xformers_attention = False
41-
4241
def get_dummy_components(self):
4342
return self._get_dummy_components()
4443

@@ -62,6 +61,13 @@ def get_dummy_inputs(self, device, seed=0):
6261

6362
return inputs
6463

64+
@unittest.skipIf(
65+
torch_device != "cuda" or not is_xformers_available(),
66+
reason="XFormers attention is only available with CUDA and `xformers` installed",
67+
)
68+
def test_xformers_attention_forwardGenerator_pass(self):
69+
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
70+
6571
def test_save_load_optional_components(self):
6672
self._test_save_load_optional_components()
6773

tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from diffusers import IFInpaintingSuperResolutionPipeline
2222
from diffusers.utils import floats_tensor
23+
from diffusers.utils.import_utils import is_xformers_available
2324
from diffusers.utils.testing_utils import skip_mps, torch_device
2425

2526
from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
3738
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"})
3839
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
3940

40-
test_xformers_attention = False
41-
4241
def get_dummy_components(self):
4342
return self._get_superresolution_dummy_components()
4443

@@ -64,6 +63,13 @@ def get_dummy_inputs(self, device, seed=0):
6463

6564
return inputs
6665

66+
@unittest.skipIf(
67+
torch_device != "cuda" or not is_xformers_available(),
68+
reason="XFormers attention is only available with CUDA and `xformers` installed",
69+
)
70+
def test_xformers_attention_forwardGenerator_pass(self):
71+
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
72+
6773
def test_save_load_optional_components(self):
6874
self._test_save_load_optional_components()
6975
