
Commit 363d1ab

Wan VAE move scaling to pipeline (#10998)
1 parent 6a0137e commit 363d1ab

3 files changed, +31 -13 lines changed


src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 2 additions & 13 deletions
@@ -715,11 +715,6 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        # Store normalization parameters as tensors
-        self.mean = torch.tensor(latents_mean)
-        self.std = torch.tensor(latents_std)
-        self.scale = torch.stack([self.mean, 1.0 / self.std])  # Shape: [2, C]
-
         self.z_dim = z_dim
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
@@ -751,7 +746,6 @@ def _count_conv3d(model)
         self._enc_feat_map = [None] * self._enc_conv_num
 
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
-        scale = self.scale.type_as(x)
         self.clear_cache()
         ## cache
         t = x.shape[2]
@@ -770,8 +764,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
 
         enc = self.quant_conv(out)
         mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
-        logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
         enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc
@@ -798,10 +790,8 @@ def encode(
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         self.clear_cache()
-        # z: [b,c,t,h,w]
-        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
 
         iter_ = z.shape[2]
         x = self.post_quant_conv(z)
@@ -835,8 +825,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
             If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
             returned.
         """
-        scale = self.scale.type_as(z)
-        decoded = self._decode(z, scale).sample
+        decoded = self._decode(z).sample
         if not return_dict:
             return (decoded,)
 
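For context, the deleted logic can be reproduced outside the model. The sketch below uses placeholder statistics and an illustrative `z_dim` (the real values come from the VAE config); it only shows the per-channel `(x - mean) * (1 / std)` transform and its inverse that these hunks remove from `_encode` and `_decode`.

```python
import torch

# Minimal sketch of the normalization removed from AutoencoderKLWan.
# latents_mean / latents_std are placeholders here; in diffusers they come
# from the model config, and this scaling now happens in the pipelines.
z_dim = 16
latents_mean = [0.0] * z_dim
latents_std = [1.0] * z_dim

mean = torch.tensor(latents_mean)
std = torch.tensor(latents_std)
scale = torch.stack([mean, 1.0 / std])  # shape [2, C], like the removed buffer

z = torch.randn(1, z_dim, 4, 8, 8)  # fake latent: [batch, C, frames, H, W]

# Encode-side normalization formerly applied to mu/logvar in _encode:
z_norm = (z - scale[0].view(1, z_dim, 1, 1, 1)) * scale[1].view(1, z_dim, 1, 1, 1)

# Decode-side inverse formerly applied at the top of _decode:
z_back = z_norm / scale[1].view(1, z_dim, 1, 1, 1) + scale[0].view(1, z_dim, 1, 1, 1)

assert torch.allclose(z, z_back, atol=1e-6)
```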

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 9 additions & 0 deletions
@@ -563,6 +563,15 @@ def __call__(
 
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
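The lines added here are the decode-side inverse of that normalization: dividing by `1 / std` and adding the mean is the same as `latents * std + mean`. Below is a standalone sketch of the same step; the helper name is illustrative, not a diffusers API, and it assumes a VAE whose config exposes `latents_mean`, `latents_std`, and `z_dim`.

```python
import torch

def denormalize_wan_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    """Illustrative helper (not part of diffusers): undo the per-channel latent
    normalization before vae.decode, mirroring the pipeline code added above."""
    latents_mean = (
        torch.tensor(vae.config.latents_mean)
        .view(1, vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    inv_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
        latents.device, latents.dtype
    )
    # latents / (1 / std) + mean == latents * std + mean
    return latents / inv_std + latents_mean
```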

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 20 additions & 0 deletions
@@ -392,6 +392,17 @@ def prepare_latents(
             latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
             latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
+        latents_mean = (
+            torch.tensor(self.vae.config.latents_mean)
+            .view(1, self.vae.config.z_dim, 1, 1, 1)
+            .to(latents.device, latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+            latents.device, latents.dtype
+        )
+
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
@@ -654,6 +665,15 @@ def __call__(
 
         if not output_type == "latent":
            latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
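The image-to-video pipeline needs both directions: `prepare_latents` normalizes the freshly encoded conditioning latents (encode side), while the decode path above denormalizes before `vae.decode`, exactly as in `pipeline_wan.py`. A sketch of the encode-side counterpart under the same assumptions (illustrative helper name; a VAE config providing `latents_mean`, `latents_std`, and `z_dim`):

```python
import torch

def normalize_wan_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    """Illustrative counterpart to the prepare_latents addition (not part of diffusers):
    map raw VAE-encoded latents into the normalized space used for denoising."""
    latents_mean = (
        torch.tensor(vae.config.latents_mean)
        .view(1, vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    inv_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
        latents.device, latents.dtype
    )
    return (latents - latents_mean) * inv_std
```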
