
Commit d71ecad

denormalize latents with the mean and std if available (huggingface#7111)
* denormalize latents with the mean and std if available
* fix denormalize
* add latent mean and std in vae config
* address sayak's comment
1 parent ac49f97 commit d71ecad

File tree

2 files changed: +18 −1 lines changed

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 0 deletions

@@ -80,6 +80,8 @@ def __init__(
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        latents_mean: Optional[Tuple[float]] = None,
+        latents_std: Optional[Tuple[float]] = None,
         force_upcast: float = True,
     ):
         super().__init__()
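
The two new config entries let a checkpoint ship the per-channel statistics of its latent distribution alongside the usual scaling_factor. A minimal sketch of passing them when constructing the VAE; the statistics below are illustrative placeholders, not values from any released checkpoint:

    from diffusers import AutoencoderKL

    # Hypothetical per-channel latent statistics (one value per latent channel).
    vae = AutoencoderKL(
        scaling_factor=0.18215,
        latents_mean=(0.1, -0.2, 0.3, 0.0),  # assumed example values
        latents_std=(1.5, 0.9, 1.1, 1.3),    # assumed example values
    )

    # Both fields are recorded in the model config, where pipelines can read them.
    print(vae.config.latents_mean, vae.config.latents_std)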

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 16 additions & 1 deletion

@@ -1313,7 +1313,22 @@ def __call__(
                 self.upcast_vae()
                 latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]

             # cast back to fp16 if needed
             if needs_upcasting:
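
The decode-side change applies the inverse of an encode-side normalization z ← (z − mean) · scaling_factor / std. A minimal round-trip sketch under that assumption (the tensor shapes and statistics are made up; view(1, 4, 1, 1) matches the 4-channel SDXL latent space in the diff):

    import torch

    scaling_factor = 0.18215
    latents_mean = torch.tensor([0.1, -0.2, 0.3, 0.0]).view(1, 4, 1, 1)  # hypothetical
    latents_std = torch.tensor([1.5, 0.9, 1.1, 1.3]).view(1, 4, 1, 1)    # hypothetical

    raw = torch.randn(1, 4, 64, 64)  # latents as the VAE encoder would produce them

    # encode side (assumed inverse of the pipeline change above)
    normalized = (raw - latents_mean) * scaling_factor / latents_std

    # decode side, exactly as added in pipeline_stable_diffusion_xl.py
    restored = normalized * latents_std / scaling_factor + latents_mean

    assert torch.allclose(restored, raw, atol=1e-5)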
