@@ -33,6 +33,7 @@ def get_down_block(
3333 use_linear_projection = False ,
3434 only_cross_attention = False ,
3535 upcast_attention = False ,
36+ resnet_time_scale_shift = "default" ,
3637):
3738 down_block_type = down_block_type [7 :] if down_block_type .startswith ("UNetRes" ) else down_block_type
3839 if down_block_type == "DownBlockFlat" :
@@ -46,6 +47,7 @@ def get_down_block(
4647 resnet_act_fn = resnet_act_fn ,
4748 resnet_groups = resnet_groups ,
4849 downsample_padding = downsample_padding ,
50+ resnet_time_scale_shift = resnet_time_scale_shift ,
4951 )
5052 elif down_block_type == "CrossAttnDownBlockFlat" :
5153 if cross_attention_dim is None :
@@ -65,6 +67,7 @@ def get_down_block(
6567 dual_cross_attention = dual_cross_attention ,
6668 use_linear_projection = use_linear_projection ,
6769 only_cross_attention = only_cross_attention ,
70+ resnet_time_scale_shift = resnet_time_scale_shift ,
6871 )
6972 raise ValueError (f"{ down_block_type } is not supported." )
7073
@@ -86,6 +89,7 @@ def get_up_block(
8689 use_linear_projection = False ,
8790 only_cross_attention = False ,
8891 upcast_attention = False ,
92+ resnet_time_scale_shift = "default" ,
8993):
9094 up_block_type = up_block_type [7 :] if up_block_type .startswith ("UNetRes" ) else up_block_type
9195 if up_block_type == "UpBlockFlat" :
@@ -99,6 +103,7 @@ def get_up_block(
99103 resnet_eps = resnet_eps ,
100104 resnet_act_fn = resnet_act_fn ,
101105 resnet_groups = resnet_groups ,
106+ resnet_time_scale_shift = resnet_time_scale_shift ,
102107 )
103108 elif up_block_type == "CrossAttnUpBlockFlat" :
104109 if cross_attention_dim is None :
@@ -118,6 +123,7 @@ def get_up_block(
118123 dual_cross_attention = dual_cross_attention ,
119124 use_linear_projection = use_linear_projection ,
120125 only_cross_attention = only_cross_attention ,
126+ resnet_time_scale_shift = resnet_time_scale_shift ,
121127 )
122128 raise ValueError (f"{ up_block_type } is not supported." )
123129
0 commit comments