Skip to content

Commit bf406ea

Browse files
Correct consist dec (huggingface#5722)
* uP
* Update src/diffusers/models/consistency_decoder_vae.py
* uP
* uP
1 parent 2fd4640 commit bf406ea

File tree

4 files changed

+96
-37
lines changed

4 files changed

+96
-37
lines changed

src/diffusers/models/autoencoder_asym_kl.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,7 @@ def _decode(
138138
def decode(
139139
self,
140140
z: torch.FloatTensor,
141+
generator: Optional[torch.Generator] = None,
141142
image: Optional[torch.FloatTensor] = None,
142143
mask: Optional[torch.FloatTensor] = None,
143144
return_dict: bool = True,

src/diffusers/models/autoencoder_tiny.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515

1616
from dataclasses import dataclass
17-
from typing import Tuple, Union
17+
from typing import Optional, Tuple, Union
1818

1919
import torch
2020

@@ -307,7 +307,9 @@ def encode(
307307
return AutoencoderTinyOutput(latents=output)
308308

309309
@apply_forward_hook
310-
def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
310+
def decode(
311+
self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
312+
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
311313
if self.use_slicing and x.shape[0] > 1:
312314
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
313315
output = torch.cat(output)

src/diffusers/models/consistency_decoder_vae.py

Lines changed: 68 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -68,11 +68,76 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
6868
"""
6969

7070
@register_to_config
71-
def __init__(self, encoder_args, decoder_args, scaling_factor, block_out_channels, latent_channels):
71+
def __init__(
72+
self,
73+
scaling_factor=0.18215,
74+
latent_channels=4,
75+
encoder_act_fn="silu",
76+
encoder_block_out_channels=(128, 256, 512, 512),
77+
encoder_double_z=True,
78+
encoder_down_block_types=(
79+
"DownEncoderBlock2D",
80+
"DownEncoderBlock2D",
81+
"DownEncoderBlock2D",
82+
"DownEncoderBlock2D",
83+
),
84+
encoder_in_channels=3,
85+
encoder_layers_per_block=2,
86+
encoder_norm_num_groups=32,
87+
encoder_out_channels=4,
88+
decoder_add_attention=False,
89+
decoder_block_out_channels=(320, 640, 1024, 1024),
90+
decoder_down_block_types=(
91+
"ResnetDownsampleBlock2D",
92+
"ResnetDownsampleBlock2D",
93+
"ResnetDownsampleBlock2D",
94+
"ResnetDownsampleBlock2D",
95+
),
96+
decoder_downsample_padding=1,
97+
decoder_in_channels=7,
98+
decoder_layers_per_block=3,
99+
decoder_norm_eps=1e-05,
100+
decoder_norm_num_groups=32,
101+
decoder_num_train_timesteps=1024,
102+
decoder_out_channels=6,
103+
decoder_resnet_time_scale_shift="scale_shift",
104+
decoder_time_embedding_type="learned",
105+
decoder_up_block_types=(
106+
"ResnetUpsampleBlock2D",
107+
"ResnetUpsampleBlock2D",
108+
"ResnetUpsampleBlock2D",
109+
"ResnetUpsampleBlock2D",
110+
),
111+
):
72112
super().__init__()
73-
self.encoder = Encoder(**encoder_args)
74-
self.decoder_unet = UNet2DModel(**decoder_args)
113+
self.encoder = Encoder(
114+
act_fn=encoder_act_fn,
115+
block_out_channels=encoder_block_out_channels,
116+
double_z=encoder_double_z,
117+
down_block_types=encoder_down_block_types,
118+
in_channels=encoder_in_channels,
119+
layers_per_block=encoder_layers_per_block,
120+
norm_num_groups=encoder_norm_num_groups,
121+
out_channels=encoder_out_channels,
122+
)
123+
124+
self.decoder_unet = UNet2DModel(
125+
add_attention=decoder_add_attention,
126+
block_out_channels=decoder_block_out_channels,
127+
down_block_types=decoder_down_block_types,
128+
downsample_padding=decoder_downsample_padding,
129+
in_channels=decoder_in_channels,
130+
layers_per_block=decoder_layers_per_block,
131+
norm_eps=decoder_norm_eps,
132+
norm_num_groups=decoder_norm_num_groups,
133+
num_train_timesteps=decoder_num_train_timesteps,
134+
out_channels=decoder_out_channels,
135+
resnet_time_scale_shift=decoder_resnet_time_scale_shift,
136+
time_embedding_type=decoder_time_embedding_type,
137+
up_block_types=decoder_up_block_types,
138+
)
75139
self.decoder_scheduler = ConsistencyDecoderScheduler()
140+
self.register_to_config(block_out_channels=encoder_block_out_channels)
76141
self.register_buffer(
77142
"means",
78143
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],

tests/models/test_models_vae.py

Lines changed: 23 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -303,39 +303,30 @@ def output_shape(self):
303303
@property
304304
def init_dict(self):
305305
return {
306-
"encoder_args": {
307-
"block_out_channels": [32, 64],
308-
"in_channels": 3,
309-
"out_channels": 4,
310-
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
311-
},
312-
"decoder_args": {
313-
"act_fn": "silu",
314-
"add_attention": False,
315-
"block_out_channels": [32, 64],
316-
"down_block_types": [
317-
"ResnetDownsampleBlock2D",
318-
"ResnetDownsampleBlock2D",
319-
],
320-
"downsample_padding": 1,
321-
"downsample_type": "conv",
322-
"dropout": 0.0,
323-
"in_channels": 7,
324-
"layers_per_block": 1,
325-
"norm_eps": 1e-05,
326-
"norm_num_groups": 32,
327-
"num_train_timesteps": 1024,
328-
"out_channels": 6,
329-
"resnet_time_scale_shift": "scale_shift",
330-
"time_embedding_type": "learned",
331-
"up_block_types": [
332-
"ResnetUpsampleBlock2D",
333-
"ResnetUpsampleBlock2D",
334-
],
335-
"upsample_type": "conv",
336-
},
306+
"encoder_block_out_channels": [32, 64],
307+
"encoder_in_channels": 3,
308+
"encoder_out_channels": 4,
309+
"encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
310+
"decoder_add_attention": False,
311+
"decoder_block_out_channels": [32, 64],
312+
"decoder_down_block_types": [
313+
"ResnetDownsampleBlock2D",
314+
"ResnetDownsampleBlock2D",
315+
],
316+
"decoder_downsample_padding": 1,
317+
"decoder_in_channels": 7,
318+
"decoder_layers_per_block": 1,
319+
"decoder_norm_eps": 1e-05,
320+
"decoder_norm_num_groups": 32,
321+
"decoder_num_train_timesteps": 1024,
322+
"decoder_out_channels": 6,
323+
"decoder_resnet_time_scale_shift": "scale_shift",
324+
"decoder_time_embedding_type": "learned",
325+
"decoder_up_block_types": [
326+
"ResnetUpsampleBlock2D",
327+
"ResnetUpsampleBlock2D",
328+
],
337329
"scaling_factor": 1,
338-
"block_out_channels": [32, 64],
339330
"latent_channels": 4,
340331
}
341332

0 commit comments

Comments
 (0)