Commit 443aa14

Fix Tiling in ConsistencyDecoderVAE (huggingface#7290)

* Fix typos
* Add docstring to `decode` method in `ConsistencyDecoderVAE`
* Fix tiling
* Enable tiled VAE decoding with customizable tile sample size and overlap factor
* Revert "Enable tiled VAE decoding with customizable tile sample size and overlap factor" (this reverts commit 1810496)
* Add VAE tiling test for `ConsistencyDecoderVAE`

Co-authored-by: Sayak Paul <[email protected]>

1 parent 288632a · commit 443aa14

File tree: 2 files changed (+60, −3 lines)


src/diffusers/models/autoencoders/consistency_decoder_vae.py

Lines changed: 30 additions & 3 deletions
```diff
@@ -63,7 +63,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
         ... ).to("cuda")
 
-        >>> pipe("horse", generator=torch.manual_seed(0)).images
+        >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
+        >>> image
         ```
     """
 
```
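For context, the corrected example runs end to end; below is a minimal sketch of the same usage (model IDs come from this commit's docstring and test, CUDA placement from the docstring; saving the image is an added illustration):

```python
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

# Plug the consistency decoder VAE into a Stable Diffusion pipeline,
# mirroring the docstring example this hunk edits.
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
).to("cuda")

# `.images` is a list of PIL images; the fixed example indexes the first
# element instead of returning the whole list.
image = pipe("horse", generator=torch.manual_seed(0)).images[0]
image.save("horse.png")  # illustration only, not part of the docstring
```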
```diff
@@ -72,6 +73,7 @@ def __init__(
         self,
         scaling_factor: float = 0.18215,
         latent_channels: int = 4,
+        sample_size: int = 32,
         encoder_act_fn: str = "silu",
         encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         encoder_double_z: bool = True,
```
```diff
@@ -153,6 +155,16 @@ def __init__(
         self.use_slicing = False
         self.use_tiling = False
 
+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
     # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
     def enable_tiling(self, use_tiling: bool = True):
         r"""
```
```diff
@@ -272,7 +284,7 @@ def encode(
         Args:
             x (`torch.FloatTensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
+                Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
                 tuple.
 
         Returns:
```
```diff
@@ -305,6 +317,19 @@ def decode(
         return_dict: bool = True,
         num_inference_steps: int = 2,
     ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+        """
+        Decodes the input latent vector `z` using the consistency decoder VAE model.
+
+        Args:
+            z (torch.FloatTensor): The input latent vector.
+            generator (Optional[torch.Generator]): The random number generator. Default is None.
+            return_dict (bool): Whether to return the output as a dictionary. Default is True.
+            num_inference_steps (int): The number of inference steps. Default is 2.
+
+        Returns:
+            Union[DecoderOutput, Tuple[torch.FloatTensor]]: The decoded output.
+
+        """
         z = (z * self.config.scaling_factor - self.means) / self.stds
 
         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
```
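The documented signature can be exercised against a dummy latent; a minimal sketch, assuming a `(1, 4, 32, 32)` CPU latent (shape and values are illustrative, not from the commit):

```python
import torch
from diffusers import ConsistencyDecoderVAE

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")

# A dummy 4-channel latent; real latents come from the encoder or a
# diffusion pipeline. Shape: (batch, latent_channels, height, width).
z = torch.randn(1, 4, 32, 32)

with torch.no_grad():
    out = vae.decode(z, generator=torch.manual_seed(0), num_inference_steps=2)

# scale_factor = 2 ** (len(block_out_channels) - 1) == 8 per the code above,
# so the decoded sample should be 8x the latent resolution.
print(out.sample.shape)  # expected: torch.Size([1, 3, 256, 256])
```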
```diff
@@ -345,7 +370,9 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
+    def tiled_encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
```
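For reference, the `blend_h` context line above is a linear crossfade across the tile seam. A toy standalone copy to show the weighting (not the library method itself, which lives on the model class; the `min` clamp is an assumption to keep the sketch safe for narrow tiles):

```python
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Crossfade the last `blend_extent` columns of `a` into the first
    # `blend_extent` columns of `b`: the weight on `a` falls linearly from
    # 1 to 0 while the weight on `b` rises from 0 to 1.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

left = torch.ones(1, 1, 1, 4)    # left tile, all 1.0
right = torch.zeros(1, 1, 1, 4)  # right tile, all 0.0
print(blend_h(left, right, 4))   # tensor([[[[1.0000, 0.7500, 0.5000, 0.2500]]]])
```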

tests/models/autoencoders/test_models_vae.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -1116,3 +1116,33 @@ def test_sd_f16(self):
         )
 
         assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+    def test_vae_tiling(self):
+        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        out_1 = pipe(
+            "horse",
+            num_inference_steps=2,
+            output_type="pt",
+            generator=torch.Generator("cpu").manual_seed(0),
+        ).images[0]
+
+        # make sure tiled vae decode yields the same result
+        pipe.enable_vae_tiling()
+        out_2 = pipe(
+            "horse",
+            num_inference_steps=2,
+            output_type="pt",
+            generator=torch.Generator("cpu").manual_seed(0),
+        ).images[0]
+
+        assert torch_all_close(out_1, out_2, atol=5e-3)
+
+        # test that tiled decode works with various shapes
+        shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
+        for shape in shapes:
+            image = torch.zeros(shape, device=torch_device)
+            pipe.vae.decode(image)
```
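The behavior the test checks through the pipeline can also be reproduced on the VAE alone; a minimal sketch using one of the odd latent shapes from the test (CPU placement is an assumption, and tiled decoding of a latent this size is slow without a GPU):

```python
import torch
from diffusers import ConsistencyDecoderVAE

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
vae.enable_tiling()  # what pipe.enable_vae_tiling() forwards to

# One of the non-square shapes from the new test: tiling must handle
# latents that are not a multiple of the tile size.
z = torch.zeros(1, 4, 73, 97)
with torch.no_grad():
    out = vae.decode(z)

print(out.sample.shape)
```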
