Skip to content

Commit dc7cd89

Browse files
authored
Add resnet_time_scale_shift to VD layers (huggingface#1757)
1 parent 8890758 commit dc7cd89

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def get_down_block(
3333
use_linear_projection=False,
3434
only_cross_attention=False,
3535
upcast_attention=False,
36+
resnet_time_scale_shift="default",
3637
):
3738
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
3839
if down_block_type == "DownBlockFlat":
@@ -46,6 +47,7 @@ def get_down_block(
4647
resnet_act_fn=resnet_act_fn,
4748
resnet_groups=resnet_groups,
4849
downsample_padding=downsample_padding,
50+
resnet_time_scale_shift=resnet_time_scale_shift,
4951
)
5052
elif down_block_type == "CrossAttnDownBlockFlat":
5153
if cross_attention_dim is None:
@@ -65,6 +67,7 @@ def get_down_block(
6567
dual_cross_attention=dual_cross_attention,
6668
use_linear_projection=use_linear_projection,
6769
only_cross_attention=only_cross_attention,
70+
resnet_time_scale_shift=resnet_time_scale_shift,
6871
)
6972
raise ValueError(f"{down_block_type} is not supported.")
7073

@@ -86,6 +89,7 @@ def get_up_block(
8689
use_linear_projection=False,
8790
only_cross_attention=False,
8891
upcast_attention=False,
92+
resnet_time_scale_shift="default",
8993
):
9094
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
9195
if up_block_type == "UpBlockFlat":
@@ -99,6 +103,7 @@ def get_up_block(
99103
resnet_eps=resnet_eps,
100104
resnet_act_fn=resnet_act_fn,
101105
resnet_groups=resnet_groups,
106+
resnet_time_scale_shift=resnet_time_scale_shift,
102107
)
103108
elif up_block_type == "CrossAttnUpBlockFlat":
104109
if cross_attention_dim is None:
@@ -118,6 +123,7 @@ def get_up_block(
118123
dual_cross_attention=dual_cross_attention,
119124
use_linear_projection=use_linear_projection,
120125
only_cross_attention=only_cross_attention,
126+
resnet_time_scale_shift=resnet_time_scale_shift,
121127
)
122128
raise ValueError(f"{up_block_type} is not supported.")
123129

0 commit comments

Comments
 (0)