@@ -114,6 +114,7 @@ def get_down_block_adapter(
     cross_attention_dim: Optional[int] = 1024,
     add_downsample: bool = True,
     upcast_attention: Optional[bool] = False,
+    use_linear_projection: Optional[bool] = True,
 ):
     num_layers = 2  # only support sd + sdxl
 
@@ -152,7 +153,7 @@ def get_down_block_adapter(
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                 )
@@ -200,6 +201,7 @@ def get_mid_block_adapter(
     num_attention_heads: Optional[int] = 1,
     cross_attention_dim: Optional[int] = 1024,
     upcast_attention: bool = False,
+    use_linear_projection: bool = True,
 ):
     # Before the midblock application, information is concatted from base to control.
     # Concat doesn't require change in number of channels
@@ -214,7 +216,7 @@ def get_mid_block_adapter(
         resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
         cross_attention_dim=cross_attention_dim,
         num_attention_heads=num_attention_heads,
-        use_linear_projection=True,
+        use_linear_projection=use_linear_projection,
         upcast_attention=upcast_attention,
     )
 
@@ -308,6 +310,7 @@ def __init__(
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         upcast_attention: bool = True,
         max_norm_num_groups: int = 32,
+        use_linear_projection: bool = True,
     ):
         super().__init__()
 
@@ -381,6 +384,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -393,6 +397,7 @@ def __init__(
             num_attention_heads=num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # up
@@ -489,6 +494,7 @@ def from_unet(
             transformer_layers_per_block=unet.config.transformer_layers_per_block,
             upcast_attention=unet.config.upcast_attention,
             max_norm_num_groups=unet.config.norm_num_groups,
+            use_linear_projection=unet.config.use_linear_projection,
         )
 
         # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
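
A minimal usage sketch (not part of the diff) of what this hunk changes in practice: `ControlNetXSAdapter.from_unet` now copies `use_linear_projection` from the base UNet's config instead of hard-coding `True`, so a conv-projection UNet (SD1.5 style) no longer gets linear projections forced onto its control blocks. The checkpoint id and the `size_ratio` value below are placeholders, not prescribed by this PR:

    from diffusers import UNet2DConditionModel, ControlNetXSAdapter

    # Any SD/SDXL UNet works; SD2.1 is only an example checkpoint.
    unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")

    # The adapter's transformer blocks now use the same projection type as the source UNet.
    adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
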
@@ -538,6 +544,7 @@ def __init__(
         addition_embed_type: Optional[str] = None,
         addition_time_embed_dim: Optional[int] = None,
         upcast_attention: bool = True,
+        use_linear_projection: bool = True,
         time_cond_proj_dim: Optional[int] = None,
         projection_class_embeddings_input_dim: Optional[int] = None,
         # additional controlnet configs
@@ -595,7 +602,12 @@ def __init__(
             time_embed_dim,
             cond_proj_dim=time_cond_proj_dim,
         )
-        self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
+        if ctrl_learn_time_embedding:
+            self.ctrl_time_embedding = TimestepEmbedding(
+                in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
+            )
+        else:
+            self.ctrl_time_embedding = None
 
         if addition_embed_type is None:
             self.base_add_time_proj = None
@@ -632,6 +644,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -647,6 +660,7 @@ def __init__(
             ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # # Create up blocks
@@ -690,6 +704,7 @@ def __init__(
                     add_upsample=not is_final_block,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -754,6 +769,7 @@ def from_unet(
754769 "addition_embed_type" ,
755770 "addition_time_embed_dim" ,
756771 "upcast_attention" ,
772+ "use_linear_projection" ,
757773 "time_cond_proj_dim" ,
758774 "projection_class_embeddings_input_dim" ,
759775 ]
@@ -1219,6 +1235,7 @@ def __init__(
         cross_attention_dim: Optional[int] = 1024,
         add_downsample: bool = True,
         upcast_attention: Optional[bool] = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         base_resnets = []
@@ -1270,7 +1287,7 @@ def __init__(
                     in_channels=base_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1282,7 +1299,7 @@ def __init__(
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                 )
@@ -1342,13 +1359,15 @@ def get_first_cross_attention(block):
             ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
             cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+            use_linear_projection = base_downblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
             base_num_attention_heads = None
             ctrl_num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_downsample = base_downblock.downsamplers is not None
 
         # create model
@@ -1367,6 +1386,7 @@ def get_first_cross_attention(block):
             cross_attention_dim=cross_attention_dim,
             add_downsample=add_downsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # # load weights
@@ -1527,6 +1547,7 @@ def __init__(
         ctrl_num_attention_heads: Optional[int] = 1,
         cross_attention_dim: Optional[int] = 1024,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
 
@@ -1541,7 +1562,7 @@ def __init__(
             resnet_groups=norm_num_groups,
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=base_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )
 
@@ -1556,7 +1577,7 @@ def __init__(
             ),
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=ctrl_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )
 
@@ -1590,6 +1611,7 @@ def get_first_cross_attention(midblock):
         ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
         cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
         upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+        use_linear_projection = base_midblock.attentions[0].use_linear_projection
 
         # create model
         model = cls(
@@ -1603,6 +1625,7 @@ def get_first_cross_attention(midblock):
             ctrl_num_attention_heads=ctrl_num_attention_heads,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # load weights
@@ -1677,6 +1700,7 @@ def __init__(
         cross_attention_dim: int = 1024,
         add_upsample: bool = True,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         resnets = []
@@ -1714,7 +1738,7 @@ def __init__(
                     in_channels=out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1753,12 +1777,14 @@ def get_first_cross_attention(block):
             num_attention_heads = get_first_cross_attention(base_upblock).heads
             cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+            use_linear_projection = base_upblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
             num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_upsample = base_upblock.upsamplers is not None
 
         # create model
@@ -1776,6 +1802,7 @@ def get_first_cross_attention(block):
             cross_attention_dim=cross_attention_dim,
             add_upsample=add_upsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # load weights
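
An end-to-end sketch (illustrative only, same placeholder checkpoint and `size_ratio` as above): because `use_linear_projection` is now registered in the adapter config, added to the list of UNet config keys copied by `UNetControlNetXSModel.from_unet`, and read back by the `from_modules` helpers, the combined model's down/mid/up blocks should end up with the same projection type (linear vs. conv) as the base UNet's transformer blocks:

    from diffusers import UNet2DConditionModel, ControlNetXSAdapter, UNetControlNetXSModel

    unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
    controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)

    # The flag flows from unet.config through both from_unet classmethods into every
    # ControlNetXSCrossAttn{Down,Mid,Up}Block2D built by this file.
    model = UNetControlNetXSModel.from_unet(unet, controlnet)
    print(model.config.use_linear_projection)  # should match unet.config.use_linear_projection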