Skip to content

Commit 5e03692

Browse files
authored
Make cross-attention check more robust (huggingface#1560)
* Make cross-attention check more robust.
* Fix copies.
1 parent bea7eb4 commit 5e03692

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

src/diffusers/models/unet_2d_blocks.py

+3
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def __init__(
343343
):
344344
super().__init__()
345345

346+
self.has_cross_attention = True
346347
self.attention_type = attention_type
347348
self.attn_num_head_channels = attn_num_head_channels
348349
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -526,6 +527,7 @@ def __init__(
526527
resnets = []
527528
attentions = []
528529

530+
self.has_cross_attention = True
529531
self.attention_type = attention_type
530532
self.attn_num_head_channels = attn_num_head_channels
531533

@@ -1110,6 +1112,7 @@ def __init__(
11101112
resnets = []
11111113
attentions = []
11121114

1115+
self.has_cross_attention = True
11131116
self.attention_type = attention_type
11141117
self.attn_num_head_channels = attn_num_head_channels
11151118

src/diffusers/models/unet_2d_condition.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def forward(
377377
# 3. down
378378
down_block_res_samples = (sample,)
379379
for downsample_block in self.down_blocks:
380-
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
380+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
381381
sample, res_samples = downsample_block(
382382
hidden_states=sample,
383383
temb=emb,
@@ -403,7 +403,7 @@ def forward(
403403
if not is_final_block and forward_upsample_size:
404404
upsample_size = down_block_res_samples[-1].shape[2:]
405405

406-
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
406+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
407407
sample = upsample_block(
408408
hidden_states=sample,
409409
temb=emb,

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def forward(
455455
# 3. down
456456
down_block_res_samples = (sample,)
457457
for downsample_block in self.down_blocks:
458-
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
458+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
459459
sample, res_samples = downsample_block(
460460
hidden_states=sample,
461461
temb=emb,
@@ -481,7 +481,7 @@ def forward(
481481
if not is_final_block and forward_upsample_size:
482482
upsample_size = down_block_res_samples[-1].shape[2:]
483483

484-
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
484+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
485485
sample = upsample_block(
486486
hidden_states=sample,
487487
temb=emb,
@@ -726,6 +726,7 @@ def __init__(
726726
resnets = []
727727
attentions = []
728728

729+
self.has_cross_attention = True
729730
self.attention_type = attention_type
730731
self.attn_num_head_channels = attn_num_head_channels
731732

@@ -924,6 +925,7 @@ def __init__(
924925
resnets = []
925926
attentions = []
926927

928+
self.has_cross_attention = True
927929
self.attention_type = attention_type
928930
self.attn_num_head_channels = attn_num_head_channels
929931

@@ -1043,6 +1045,7 @@ def __init__(
10431045
):
10441046
super().__init__()
10451047

1048+
self.has_cross_attention = True
10461049
self.attention_type = attention_type
10471050
self.attn_num_head_channels = attn_num_head_channels
10481051
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

0 commit comments

Comments (0)