Commit 170ebd2

[UNet2DConditionModel] add an option to upcast attention to fp32 (huggingface#1590)
upcast attention
1 parent dc87f52 commit 170ebd2
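
Before the file-by-file diff, a minimal usage sketch (not part of the commit itself). It assumes the model's default configuration and a working torch/diffusers install, and it picks fp16 only when a GPU is available, since half precision is where the upcast matters; the shapes and timestep value are arbitrary.

import torch
from diffusers import UNet2DConditionModel

# hypothetical instantiation; every argument except upcast_attention keeps its default
unet = UNet2DConditionModel(upcast_attention=True)

# run in fp16 on GPU if possible, otherwise fall back to fp32 on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
unet = unet.to(device=device, dtype=dtype).eval()

sample = torch.randn(1, 4, 64, 64, device=device, dtype=dtype)                 # noisy latents
timestep = torch.tensor([10], device=device)                                   # arbitrary step
encoder_hidden_states = torch.randn(1, 77, 1280, device=device, dtype=dtype)   # text embeddings

with torch.no_grad():
    out = unet(sample, timestep, encoder_hidden_states).sample

print(out.shape, out.dtype)  # same spatial shape and dtype as the input sample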

File tree

4 files changed: +50 -1 lines changed

src/diffusers/models/attention.py (+26 -1)
src/diffusers/models/unet_2d_blocks.py (+10)
src/diffusers/models/unet_2d_condition.py (+4)
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (+10)

src/diffusers/models/attention.py (+26 -1)

@@ -101,6 +101,7 @@ def __init__(
         num_embeds_ada_norm: Optional[int] = None,
         use_linear_projection: bool = False,
         only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -159,6 +160,7 @@ def __init__(
                     num_embeds_ada_norm=num_embeds_ada_norm,
                     attention_bias=attention_bias,
                     only_cross_attention=only_cross_attention,
+                    upcast_attention=upcast_attention,
                 )
                 for d in range(num_layers)
             ]
@@ -403,6 +405,7 @@ def __init__(
         num_embeds_ada_norm: Optional[int] = None,
         attention_bias: bool = False,
         only_cross_attention: bool = False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -416,6 +419,7 @@ def __init__(
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
 
@@ -428,6 +432,7 @@ def __init__(
                 dim_head=attention_head_dim,
                 dropout=dropout,
                 bias=attention_bias,
+                upcast_attention=upcast_attention,
             )  # is self-attn if context is none
         else:
             self.attn2 = None
@@ -525,10 +530,12 @@ def __init__(
         dim_head: int = 64,
         dropout: float = 0.0,
         bias=False,
+        upcast_attention: bool = False,
     ):
         super().__init__()
         inner_dim = dim_head * heads
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
 
         self.scale = dim_head**-0.5
         self.heads = heads
@@ -601,6 +608,10 @@ def forward(self, hidden_states, context=None, mask=None):
         return hidden_states
 
     def _attention(self, query, key, value):
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
         attention_scores = torch.baddbmm(
             torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
             query,
@@ -609,8 +620,11 @@ def _attention(self, query, key, value):
             alpha=self.scale,
         )
         attention_probs = attention_scores.softmax(dim=-1)
-        # compute attention output
 
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
         hidden_states = torch.bmm(attention_probs, value)
 
         # reshape hidden_states
@@ -626,6 +640,14 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+
+            if self.upcast_attention:
+                query_slice = query_slice.float()
+                key_slice = key_slice.float()
+
             attn_slice = torch.baddbmm(
                 torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
                 query[start_idx:end_idx],
@@ -634,6 +656,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
                 alpha=self.scale,
             )
             attn_slice = attn_slice.softmax(dim=-1)
+
+            # cast back to the original dtype
+            attn_slice = attn_slice.to(value.dtype)
             attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
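
For context on why the upcast helps (a standalone illustration, not taken from the commit): fp16's largest finite value is 65504, so attention logits beyond that overflow to inf, and a softmax over a row of infs degenerates into NaNs. Computing the scores and the softmax in fp32 and casting the probabilities back to the value dtype, as _attention and _sliced_attention now do when upcast_attention is set, keeps the result finite.

import torch

# large logits overflow when stored in fp16 (max finite value: 65504)
logits_fp16 = torch.full((4,), 70000.0).half()
print(logits_fp16)  # tensor([inf, inf, inf, inf], dtype=torch.float16)
# a softmax over that row would produce NaNs (inf - inf inside the normalization)

# the same logits kept in fp32 stay finite; softmax there, then cast back to fp16
logits_fp32 = torch.full((4,), 70000.0)
probs = logits_fp32.softmax(dim=-1).half()
print(probs)  # tensor([0.2500, 0.2500, 0.2500, 0.2500], dtype=torch.float16)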

src/diffusers/models/unet_2d_blocks.py (+10)

@@ -35,6 +35,7 @@ def get_down_block(
     dual_cross_attention=False,
     use_linear_projection=False,
     only_cross_attention=False,
+    upcast_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -80,6 +81,7 @@ def get_down_block(
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -146,6 +148,7 @@ def get_up_block(
     dual_cross_attention=False,
     use_linear_projection=False,
     only_cross_attention=False,
+    upcast_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -178,6 +181,7 @@ def get_up_block(
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -335,6 +339,7 @@ def __init__(
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
+        upcast_attention=False,
     ):
         super().__init__()
 
@@ -370,6 +375,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -514,6 +520,7 @@ def __init__(
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -549,6 +556,7 @@ def __init__(
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -1096,6 +1104,7 @@ def __init__(
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -1133,6 +1142,7 @@ def __init__(
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:

src/diffusers/models/unet_2d_condition.py (+4)

@@ -111,6 +111,7 @@ def __init__(
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
     ):
         super().__init__()
 
@@ -163,6 +164,7 @@ def __init__(
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.down_blocks.append(down_block)
 
@@ -179,6 +181,7 @@ def __init__(
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
         )
 
         # count how many layers upsample the images
@@ -219,6 +222,7 @@ def __init__(
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
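
Since UNet2DConditionModel.__init__ is wrapped by register_to_config, the new argument should also be recorded in the model config; a short sketch under that assumption (it instantiates the full default-sized model, so it allocates a few GB of memory):

from diffusers import UNet2DConditionModel

# the flag is stored like any other config field and written to config.json on save_pretrained
unet = UNet2DConditionModel(upcast_attention=True)
print(unet.config.upcast_attention)  # True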

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (+10)

@@ -189,6 +189,7 @@ def __init__(
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
         num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
     ):
         super().__init__()
 
@@ -241,6 +242,7 @@ def __init__(
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.down_blocks.append(down_block)
 
@@ -257,6 +259,7 @@ def __init__(
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
         )
 
         # count how many layers upsample the images
@@ -297,6 +300,7 @@ def __init__(
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -716,6 +720,7 @@ def __init__(
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -751,6 +756,7 @@ def __init__(
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -912,6 +918,7 @@ def __init__(
         dual_cross_attention=False,
         use_linear_projection=False,
         only_cross_attention=False,
+        upcast_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -949,6 +956,7 @@ def __init__(
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
@@ -1031,6 +1039,7 @@ def __init__(
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
+        upcast_attention=False,
     ):
         super().__init__()
 
@@ -1066,6 +1075,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
                     )
                 )
             else:
