Skip to content

Commit e550163

Browse files
[Vae] Make sure all vae's work with latent diffusion models (huggingface#5880)
* add comments to explain the code better
* add comments to explain the code better
* add comments to explain the code better
* add comments to explain the code better
* add comments to explain the code better
* fix more
* fix more
* fix more
* fix more
* fix more
* fix more
1 parent 20f0cbc commit e550163

21 files changed

+277
-112
lines changed

src/diffusers/models/autoencoder_asym_kl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def __init__(
108108
self.use_slicing = False
109109
self.use_tiling = False
110110

111+
self.register_to_config(block_out_channels=up_block_out_channels)
112+
self.register_to_config(force_upcast=False)
113+
111114
@apply_forward_hook
112115
def encode(
113116
self, x: torch.FloatTensor, return_dict: bool = True

src/diffusers/models/autoencoder_tiny.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ def __init__(
148148
self.tile_sample_min_size = 512
149149
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
150150

151+
self.register_to_config(block_out_channels=decoder_block_out_channels)
152+
self.register_to_config(force_upcast=False)
153+
151154
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
152155
if isinstance(module, (EncoderTiny, DecoderTiny)):
153156
module.gradient_checkpointing = value

src/diffusers/models/consistency_decoder_vae.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(
138138
)
139139
self.decoder_scheduler = ConsistencyDecoderScheduler()
140140
self.register_to_config(block_out_channels=encoder_block_out_channels)
141+
self.register_to_config(force_upcast=False)
141142
self.register_buffer(
142143
"means",
143144
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,13 @@
7676

7777

7878
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
79-
def retrieve_latents(encoder_output, generator):
80-
if hasattr(encoder_output, "latent_dist"):
79+
def retrieve_latents(
80+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
81+
):
82+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
8183
return encoder_output.latent_dist.sample(generator)
84+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
85+
return encoder_output.latent_dist.mode()
8286
elif hasattr(encoder_output, "latents"):
8387
return encoder_output.latents
8488
else:

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,13 @@
9292

9393

9494
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
95-
def retrieve_latents(encoder_output, generator):
96-
if hasattr(encoder_output, "latent_dist"):
95+
def retrieve_latents(
96+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
97+
):
98+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
9799
return encoder_output.latent_dist.sample(generator)
100+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
101+
return encoder_output.latent_dist.mode()
98102
elif hasattr(encoder_output, "latents"):
99103
return encoder_output.latents
100104
else:

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,13 @@
104104

105105

106106
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
107-
def retrieve_latents(encoder_output, generator):
108-
if hasattr(encoder_output, "latent_dist"):
107+
def retrieve_latents(
108+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
109+
):
110+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
109111
return encoder_output.latent_dist.sample(generator)
112+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
113+
return encoder_output.latent_dist.mode()
110114
elif hasattr(encoder_output, "latents"):
111115
return encoder_output.latents
112116
else:

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@
5454
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5555

5656

57+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
58+
def retrieve_latents(
59+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
60+
):
61+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
62+
return encoder_output.latent_dist.sample(generator)
63+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
64+
return encoder_output.latent_dist.mode()
65+
elif hasattr(encoder_output, "latents"):
66+
return encoder_output.latents
67+
else:
68+
raise AttributeError("Could not access latents of provided encoder_output")
69+
70+
5771
EXAMPLE_DOC_STRING = """
5872
Examples:
5973
```py
@@ -824,12 +838,12 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
824838

825839
if isinstance(generator, list):
826840
image_latents = [
827-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
841+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
828842
for i in range(image.shape[0])
829843
]
830844
image_latents = torch.cat(image_latents, dim=0)
831845
else:
832-
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
846+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
833847

834848
if self.vae.config.force_upcast:
835849
self.vae.to(dtype)

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,13 @@
133133

134134

135135
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
136-
def retrieve_latents(encoder_output, generator):
137-
if hasattr(encoder_output, "latent_dist"):
136+
def retrieve_latents(
137+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
138+
):
139+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
138140
return encoder_output.latent_dist.sample(generator)
141+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
142+
return encoder_output.latent_dist.mode()
139143
elif hasattr(encoder_output, "latents"):
140144
return encoder_output.latents
141145
else:

src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,13 @@
4444

4545

4646
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
47-
def retrieve_latents(encoder_output, generator):
48-
if hasattr(encoder_output, "latent_dist"):
47+
def retrieve_latents(
48+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
49+
):
50+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
4951
return encoder_output.latent_dist.sample(generator)
52+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
53+
return encoder_output.latent_dist.mode()
5054
elif hasattr(encoder_output, "latents"):
5155
return encoder_output.latents
5256
else:

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,13 @@
3535

3636

3737
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
38-
def retrieve_latents(encoder_output, generator):
39-
if hasattr(encoder_output, "latent_dist"):
38+
def retrieve_latents(
39+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
40+
):
41+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
4042
return encoder_output.latent_dist.sample(generator)
43+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
44+
return encoder_output.latent_dist.mode()
4145
elif hasattr(encoder_output, "latents"):
4246
return encoder_output.latents
4347
else:

0 commit comments

Comments (0)