Skip to content

Commit 8dba180

Browse files
Gothos and sayakpaul authored
Added support to create asymmetrical U-Net structures (huggingface#5400)
* Added args, kwargs to ```U * Add UNetMidBlock2D as a supported mid block type * Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init * Update unet_2d_condition.py * Update unet_2d_condition.py * Update unet_2d_condition.py * Update unet_2d_condition.py * Update unet_2d_condition.py * Update unet_2d_condition.py * Update unet_2d_condition.py * Update unet_2d_condition.py * Update unet_2d_blocks.py * Update unet_2d_blocks.py * Update unet_2d_blocks.py * Update unet_2d_condition.py * Update unet_2d_blocks.py * Updated docstring, increased check strictness Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block``` and updated checking for nested list type ```transformer_layers_per_block``` * Add basic shape-check test for asymmetrical unets * Update src/diffusers/models/unet_2d_blocks.py Removed blank line Co-authored-by: Sayak Paul <[email protected]> * Update unet_2d_condition.py Remove blank space * Update unet_2d_condition.py Changed docstring for `mid_block_type` * Fixed docstring and wrong default value * Reformat with black * Reformat with necessary commands * Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency * Removed args, kwargs, use on mid-block type * Make fix-copies * Update src/diffusers/models/unet_2d_condition.py Wrap into single line Co-authored-by: Sayak Paul <[email protected]> * make fix-copies --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 5366db5 commit 8dba180

File tree

7 files changed

+272
-46
lines changed

7 files changed

+272
-46
lines changed

src/diffusers/models/unet_2d_blocks.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Dict, Optional, Tuple
14+
from typing import Any, Dict, Optional, Tuple, Union
1515

1616
import numpy as np
1717
import torch
@@ -634,7 +634,7 @@ def __init__(
634634
temb_channels: int,
635635
dropout: float = 0.0,
636636
num_layers: int = 1,
637-
transformer_layers_per_block: int = 1,
637+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
638638
resnet_eps: float = 1e-6,
639639
resnet_time_scale_shift: str = "default",
640640
resnet_act_fn: str = "swish",
@@ -654,6 +654,10 @@ def __init__(
654654
self.num_attention_heads = num_attention_heads
655655
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
656656

657+
# support for variable transformer layers per block
658+
if isinstance(transformer_layers_per_block, int):
659+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
660+
657661
# there is always at least one resnet
658662
resnets = [
659663
ResnetBlock2D(
@@ -671,14 +675,14 @@ def __init__(
671675
]
672676
attentions = []
673677

674-
for _ in range(num_layers):
678+
for i in range(num_layers):
675679
if not dual_cross_attention:
676680
attentions.append(
677681
Transformer2DModel(
678682
num_attention_heads,
679683
in_channels // num_attention_heads,
680684
in_channels=in_channels,
681-
num_layers=transformer_layers_per_block,
685+
num_layers=transformer_layers_per_block[i],
682686
cross_attention_dim=cross_attention_dim,
683687
norm_num_groups=resnet_groups,
684688
use_linear_projection=use_linear_projection,
@@ -1018,7 +1022,7 @@ def __init__(
10181022
temb_channels: int,
10191023
dropout: float = 0.0,
10201024
num_layers: int = 1,
1021-
transformer_layers_per_block: int = 1,
1025+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
10221026
resnet_eps: float = 1e-6,
10231027
resnet_time_scale_shift: str = "default",
10241028
resnet_act_fn: str = "swish",
@@ -1041,6 +1045,8 @@ def __init__(
10411045

10421046
self.has_cross_attention = True
10431047
self.num_attention_heads = num_attention_heads
1048+
if isinstance(transformer_layers_per_block, int):
1049+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
10441050

10451051
for i in range(num_layers):
10461052
in_channels = in_channels if i == 0 else out_channels
@@ -1064,7 +1070,7 @@ def __init__(
10641070
num_attention_heads,
10651071
out_channels // num_attention_heads,
10661072
in_channels=out_channels,
1067-
num_layers=transformer_layers_per_block,
1073+
num_layers=transformer_layers_per_block[i],
10681074
cross_attention_dim=cross_attention_dim,
10691075
norm_num_groups=resnet_groups,
10701076
use_linear_projection=use_linear_projection,
@@ -2167,7 +2173,7 @@ def __init__(
21672173
resolution_idx: int = None,
21682174
dropout: float = 0.0,
21692175
num_layers: int = 1,
2170-
transformer_layers_per_block: int = 1,
2176+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
21712177
resnet_eps: float = 1e-6,
21722178
resnet_time_scale_shift: str = "default",
21732179
resnet_act_fn: str = "swish",
@@ -2190,6 +2196,9 @@ def __init__(
21902196
self.has_cross_attention = True
21912197
self.num_attention_heads = num_attention_heads
21922198

2199+
if isinstance(transformer_layers_per_block, int):
2200+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
2201+
21932202
for i in range(num_layers):
21942203
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
21952204
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -2214,7 +2223,7 @@ def __init__(
22142223
num_attention_heads,
22152224
out_channels // num_attention_heads,
22162225
in_channels=out_channels,
2217-
num_layers=transformer_layers_per_block,
2226+
num_layers=transformer_layers_per_block[i],
22182227
cross_attention_dim=cross_attention_dim,
22192228
norm_num_groups=resnet_groups,
22202229
use_linear_projection=use_linear_projection,

src/diffusers/models/unet_2d_condition.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
from .modeling_utils import ModelMixin
4545
from .unet_2d_blocks import (
46+
UNetMidBlock2D,
4647
UNetMidBlock2DCrossAttn,
4748
UNetMidBlock2DSimpleCrossAttn,
4849
get_down_block,
@@ -86,7 +87,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
8687
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
8788
The tuple of downsample blocks to use.
8889
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89-
Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
90+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
9091
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
9192
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
9293
The tuple of upsample blocks to use.
@@ -105,10 +106,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
105106
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
106107
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
107108
The dimension of the cross attention features.
108-
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
109+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
109110
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
110111
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111112
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
113+
reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to None):
114+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
115+
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
116+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
117+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112118
encoder_hid_dim (`int`, *optional*, defaults to None):
113119
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
114120
dimension to `cross_attention_dim`.
@@ -142,9 +148,9 @@ class conditioning with `class_embed_type` equal to `None`.
142148
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
143149
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
144150
The dimension of `cond_proj` layer in the timestep embedding.
145-
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
146-
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
147-
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
151+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
152+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
153+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
148154
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
149155
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
150156
embeddings with the class embeddings.
@@ -184,7 +190,8 @@ def __init__(
184190
norm_num_groups: Optional[int] = 32,
185191
norm_eps: float = 1e-5,
186192
cross_attention_dim: Union[int, Tuple[int]] = 1280,
187-
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
193+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
194+
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
188195
encoder_hid_dim: Optional[int] = None,
189196
encoder_hid_dim_type: Optional[str] = None,
190197
attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -265,6 +272,10 @@ def __init__(
265272
raise ValueError(
266273
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
267274
)
275+
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
276+
for layer_number_per_block in transformer_layers_per_block:
277+
if isinstance(layer_number_per_block, list):
278+
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
268279

269280
# input
270281
conv_in_padding = (conv_in_kernel - 1) // 2
@@ -500,6 +511,19 @@ def __init__(
500511
only_cross_attention=mid_block_only_cross_attention,
501512
cross_attention_norm=cross_attention_norm,
502513
)
514+
elif mid_block_type == "UNetMidBlock2D":
515+
self.mid_block = UNetMidBlock2D(
516+
in_channels=block_out_channels[-1],
517+
temb_channels=blocks_time_embed_dim,
518+
dropout=dropout,
519+
num_layers=0,
520+
resnet_eps=norm_eps,
521+
resnet_act_fn=act_fn,
522+
output_scale_factor=mid_block_scale_factor,
523+
resnet_groups=norm_num_groups,
524+
resnet_time_scale_shift=resnet_time_scale_shift,
525+
add_attention=False,
526+
)
503527
elif mid_block_type is None:
504528
self.mid_block = None
505529
else:
@@ -513,7 +537,11 @@ def __init__(
513537
reversed_num_attention_heads = list(reversed(num_attention_heads))
514538
reversed_layers_per_block = list(reversed(layers_per_block))
515539
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
516-
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
540+
reversed_transformer_layers_per_block = (
541+
list(reversed(transformer_layers_per_block))
542+
if reverse_transformer_layers_per_block is None
543+
else reverse_transformer_layers_per_block
544+
)
517545
only_cross_attention = list(reversed(only_cross_attention))
518546

519547
output_channel = reversed_block_out_channels[0]
@@ -1062,14 +1090,18 @@ def forward(
10621090

10631091
# 4. mid
10641092
if self.mid_block is not None:
1065-
sample = self.mid_block(
1066-
sample,
1067-
emb,
1068-
encoder_hidden_states=encoder_hidden_states,
1069-
attention_mask=attention_mask,
1070-
cross_attention_kwargs=cross_attention_kwargs,
1071-
encoder_attention_mask=encoder_attention_mask,
1072-
)
1093+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1094+
sample = self.mid_block(
1095+
sample,
1096+
emb,
1097+
encoder_hidden_states=encoder_hidden_states,
1098+
attention_mask=attention_mask,
1099+
cross_attention_kwargs=cross_attention_kwargs,
1100+
encoder_attention_mask=encoder_attention_mask,
1101+
)
1102+
else:
1103+
sample = self.mid_block(sample, emb)
1104+
10731105
# To support T2I-Adapter-XL
10741106
if (
10751107
is_adapter

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
106106
feature_extractor ([`~transformers.CLIPImageProcessor`]):
107107
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
108108
"""
109+
109110
model_cpu_offload_seq = "text_encoder->unet->vae"
110111
_optional_components = ["safety_checker", "feature_extractor"]
111112
_exclude_from_cpu_offload = ["safety_checker"]

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class AltDiffusionImg2ImgPipeline(
134134
feature_extractor ([`~transformers.CLIPImageProcessor`]):
135135
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
136136
"""
137+
137138
model_cpu_offload_seq = "text_encoder->unet->vae"
138139
_optional_components = ["safety_checker", "feature_extractor"]
139140
_exclude_from_cpu_offload = ["safety_checker"]

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
324324
if "disable_self_attentions" in unet_params:
325325
config["only_cross_attention"] = unet_params.disable_self_attentions
326326

327-
if "num_classes" in unet_params and type(unet_params.num_classes) == int:
327+
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
328328
config["num_class_embeds"] = unet_params.num_classes
329329

330330
if controlnet:

0 commit comments

Comments
 (0)