@@ -62,6 +62,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
             can be fine-tuned / trained to a lower range without losing too much precision in which case
             `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+        mid_block_add_attention (`bool`, *optional*, default to `True`):
+            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
+            mid_block will only have resnet blocks
     """

     _supports_gradient_checkpointing = True
@@ -87,6 +90,7 @@ def __init__(
         force_upcast: float = True,
         use_quant_conv: bool = True,
         use_post_quant_conv: bool = True,
+        mid_block_add_attention: bool = True,
     ):
         super().__init__()

@@ -100,6 +104,7 @@ def __init__(
             act_fn=act_fn,
             norm_num_groups=norm_num_groups,
             double_z=True,
+            mid_block_add_attention=mid_block_add_attention,
         )

         # pass init params to Decoder
@@ -111,6 +116,7 @@ def __init__(
             layers_per_block=layers_per_block,
             norm_num_groups=norm_num_groups,
             act_fn=act_fn,
+            mid_block_add_attention=mid_block_add_attention,
         )

         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
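The net effect of the diff is a single new constructor flag, `mid_block_add_attention`, documented in the class docstring and threaded through to both the `Encoder` and the `Decoder`, presumably to support checkpoints whose mid block was trained without self-attention. A minimal usage sketch is below; it assumes the default (tiny) `AutoencoderKL` config and a hypothetical 64x64 input, and only illustrates that `mid_block_add_attention=False` builds an attention-free, resnet-only mid block.

```python
import torch
from diffusers import AutoencoderKL

# Default behaviour (unchanged): the Encoder/Decoder mid_block keeps its
# attention block alongside the resnet blocks.
vae = AutoencoderKL(mid_block_add_attention=True)

# New option from this diff: build the mid_block with resnet blocks only.
vae_conv_only = AutoencoderKL(mid_block_add_attention=False)

# Hypothetical smoke test with the default config and a 64x64 input.
sample = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    posterior = vae_conv_only.encode(sample).latent_dist
    latents = posterior.sample()
    reconstruction = vae_conv_only.decode(latents).sample
print(latents.shape, reconstruction.shape)
```

Because `__init__` is registered to the config, the flag is also serialized with the model config, so a checkpoint saved from `vae_conv_only` reloads without the mid-block attention weights.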