Skip to content

Commit c4a3b09

Browse files
authored
[UNet2DConditionModel] add cross_attention_dim as an argument (huggingface#155)
add cross_attention_dim as an argument
1 parent: 616c3a4 · commit: c4a3b09

File tree

2 files changed: 12 additions (+12), 0 deletions (-0)

src/diffusers/models/unet_2d_condition.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ def __init__(
2828
act_fn="silu",
2929
norm_num_groups=32,
3030
norm_eps=1e-5,
31+
cross_attention_dim=1280,
3132
attention_head_dim=8,
3233
):
3334
super().__init__()
@@ -64,6 +65,7 @@ def __init__(
6465
add_downsample=not is_final_block,
6566
resnet_eps=norm_eps,
6667
resnet_act_fn=act_fn,
68+
cross_attention_dim=cross_attention_dim,
6769
attn_num_head_channels=attention_head_dim,
6870
downsample_padding=downsample_padding,
6971
)
@@ -77,6 +79,7 @@ def __init__(
7779
resnet_act_fn=act_fn,
7880
output_scale_factor=mid_block_scale_factor,
7981
resnet_time_scale_shift="default",
82+
cross_attention_dim=cross_attention_dim,
8083
attn_num_head_channels=attention_head_dim,
8184
resnet_groups=norm_num_groups,
8285
)
@@ -101,6 +104,7 @@ def __init__(
101104
add_upsample=not is_final_block,
102105
resnet_eps=norm_eps,
103106
resnet_act_fn=act_fn,
107+
cross_attention_dim=cross_attention_dim,
104108
attn_num_head_channels=attention_head_dim,
105109
)
106110
self.up_blocks.append(up_block)

src/diffusers/models/unet_blocks.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def get_down_block(
3131
resnet_eps,
3232
resnet_act_fn,
3333
attn_num_head_channels,
34+
cross_attention_dim=None,
3435
downsample_padding=None,
3536
):
3637
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
@@ -58,6 +59,8 @@ def get_down_block(
5859
attn_num_head_channels=attn_num_head_channels,
5960
)
6061
elif down_block_type == "CrossAttnDownBlock2D":
62+
if cross_attention_dim is None:
63+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
6164
return CrossAttnDownBlock2D(
6265
num_layers=num_layers,
6366
in_channels=in_channels,
@@ -67,6 +70,7 @@ def get_down_block(
6770
resnet_eps=resnet_eps,
6871
resnet_act_fn=resnet_act_fn,
6972
downsample_padding=downsample_padding,
73+
cross_attention_dim=cross_attention_dim,
7074
attn_num_head_channels=attn_num_head_channels,
7175
)
7276
elif down_block_type == "SkipDownBlock2D":
@@ -115,6 +119,7 @@ def get_up_block(
115119
resnet_eps,
116120
resnet_act_fn,
117121
attn_num_head_channels,
122+
cross_attention_dim=None,
118123
):
119124
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
120125
if up_block_type == "UpBlock2D":
@@ -129,6 +134,8 @@ def get_up_block(
129134
resnet_act_fn=resnet_act_fn,
130135
)
131136
elif up_block_type == "CrossAttnUpBlock2D":
137+
if cross_attention_dim is None:
138+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
132139
return CrossAttnUpBlock2D(
133140
num_layers=num_layers,
134141
in_channels=in_channels,
@@ -138,6 +145,7 @@ def get_up_block(
138145
add_upsample=add_upsample,
139146
resnet_eps=resnet_eps,
140147
resnet_act_fn=resnet_act_fn,
148+
cross_attention_dim=cross_attention_dim,
141149
attn_num_head_channels=attn_num_head_channels,
142150
)
143151
elif up_block_type == "AttnUpBlock2D":

0 commit comments

Comments
 (0)