@@ -31,6 +31,7 @@ def get_down_block(
3131 resnet_eps ,
3232 resnet_act_fn ,
3333 attn_num_head_channels ,
34+ cross_attention_dim = None ,
3435 downsample_padding = None ,
3536):
3637 down_block_type = down_block_type [7 :] if down_block_type .startswith ("UNetRes" ) else down_block_type
@@ -58,6 +59,8 @@ def get_down_block(
5859 attn_num_head_channels = attn_num_head_channels ,
5960 )
6061 elif down_block_type == "CrossAttnDownBlock2D" :
62+ if cross_attention_dim is None :
63+ raise ValueError ("cross_attention_dim must be specified for CrossAttnDownBlock2D" )
6164 return CrossAttnDownBlock2D (
6265 num_layers = num_layers ,
6366 in_channels = in_channels ,
@@ -67,6 +70,7 @@ def get_down_block(
6770 resnet_eps = resnet_eps ,
6871 resnet_act_fn = resnet_act_fn ,
6972 downsample_padding = downsample_padding ,
73+ cross_attention_dim = cross_attention_dim ,
7074 attn_num_head_channels = attn_num_head_channels ,
7175 )
7276 elif down_block_type == "SkipDownBlock2D" :
@@ -115,6 +119,7 @@ def get_up_block(
115119 resnet_eps ,
116120 resnet_act_fn ,
117121 attn_num_head_channels ,
122+ cross_attention_dim = None ,
118123):
119124 up_block_type = up_block_type [7 :] if up_block_type .startswith ("UNetRes" ) else up_block_type
120125 if up_block_type == "UpBlock2D" :
@@ -129,6 +134,8 @@ def get_up_block(
129134 resnet_act_fn = resnet_act_fn ,
130135 )
131136 elif up_block_type == "CrossAttnUpBlock2D" :
137+ if cross_attention_dim is None :
138+ raise ValueError ("cross_attention_dim must be specified for CrossAttnUpBlock2D" )
132139 return CrossAttnUpBlock2D (
133140 num_layers = num_layers ,
134141 in_channels = in_channels ,
@@ -138,6 +145,7 @@ def get_up_block(
138145 add_upsample = add_upsample ,
139146 resnet_eps = resnet_eps ,
140147 resnet_act_fn = resnet_act_fn ,
148+ cross_attention_dim = cross_attention_dim ,
141149 attn_num_head_channels = attn_num_head_channels ,
142150 )
143151 elif up_block_type == "AttnUpBlock2D" :
0 commit comments