
[BUG] fixes in Kandinsky pipeline #11080


Merged
merged 8 commits into main on Apr 21, 2025

Conversation

ishan-modi
Contributor

What does this PR do?

Fixes #11060

Who can review?

@DN6

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Collaborator

yiyixuxu commented Apr 8, 2025

thanks @ishan-modi !
are you able to run the docstring examples for these pipelines and check whether the outputs are the same on this branch vs on main?

@ishan-modi
Contributor Author

ishan-modi commented Apr 8, 2025

I tried running both branches for Kandinsky 3, and the results are below.

They are slightly different, potentially due to the use of quantization and the balanced device_map strategy (forced by limited GPU memory), though I am not sure.

[images: main | fixes-issue-11060 | diff]

Also, preprocess and postprocess are exactly the same as before:

import torch
import numpy as np
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image
from diffusers.pipelines.pipeline_utils import numpy_to_pil

input_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
)
midpoint_image = torch.randn(1, 4, 64, 64)
image_processor = VaeImageProcessor(
    vae_scale_factor=2**3,
    vae_latent_channels=4,
    resample="bicubic",
    reducing_gap=1,
)


def preprocess(image, image_processor=None, branch="default"):
    if branch == "main":
        # Manual preprocessing as done on main: scale to [-1, 1], HWC -> CHW, add batch dim
        arr = np.array(image.convert("RGB"))
        arr = arr.astype(np.float32) / 127.5 - 1
        arr = np.transpose(arr, [2, 0, 1])
        return torch.from_numpy(arr).unsqueeze(0)
    # This branch: delegate to VaeImageProcessor
    return image_processor.preprocess(image)


def postprocess(image, image_processor=None, branch="default"):
    if branch == "main":
        # Manual postprocessing as done on main: rescale to [0, 1], CHW -> HWC, to PIL
        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return numpy_to_pil(image)[0]
    # This branch: delegate to VaeImageProcessor
    return image_processor.postprocess(image)[0]


if torch.equal(preprocess(input_image, image_processor), preprocess(input_image, branch="main")):
    print("Preprocessed images are exactly the same.")

if list(postprocess(midpoint_image, image_processor).getdata()) == list(
    postprocess(midpoint_image, branch="main").getdata()
):
    print("Postprocessed images are exactly the same.")

Collaborator

@yiyixuxu yiyixuxu left a comment


thanks!

@yiyixuxu
Collaborator

yiyixuxu commented Apr 8, 2025

I think there are some Kandinsky test failures that are potentially related: https://github.com/huggingface/diffusers/actions/runs/14330695758/job/40203664079?pr=11080#step:6:33932

@ishan-modi
Contributor Author

Thanks for the review, fixed the tests!

resample="bicubic",
reducing_gap=1,
)
kwargs = {}
Collaborator


Ohh, we can just give them a default value like this:

movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else ..
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else ..

to be consistent with how this is handled in other pipelines, for example https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py#L218
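For illustration, the suggested default-value pattern can be sketched on a minimal mock. The mock classes and the fallback values (8 and 4) below are hypothetical stand-ins, not the actual pipeline code:

```python
# Sketch of the suggested pattern: derive movq-based values only when the
# movq component is present, otherwise fall back to a default. Mock classes
# below are hypothetical stand-ins for the real pipeline components.

class MovQConfig:
    block_out_channels = (128, 256, 256, 512)
    latent_channels = 4

class MovQ:
    config = MovQConfig()

class Pipeline:
    def __init__(self, movq=None):
        self.movq = movq
        # Guard with getattr so the pipeline still constructs when movq is
        # missing; fall back to assumed defaults (8 and 4 here).
        self.movq_scale_factor = (
            2 ** (len(self.movq.config.block_out_channels) - 1)
            if getattr(self, "movq", None)
            else 8
        )
        self.movq_latent_channels = (
            self.movq.config.latent_channels if getattr(self, "movq", None) else 4
        )

with_movq = Pipeline(movq=MovQ())
without_movq = Pipeline()
print(with_movq.movq_scale_factor)  # 2 ** (4 - 1) = 8
print(without_movq.movq_latent_channels)  # falls back to the default, 4
```

This keeps component-derived attributes optional, so partially loaded pipelines (e.g. for offloading or component reuse) do not crash in `__init__`.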

Contributor

@hlky hlky left a comment


Same comment for all pipelines.

else:
image = latents
image = self.image_processor.postprocess(image, output_type)
Contributor


This breaks output_type == "latent". image_processor.postprocess should not be applied to latent output.

Contributor Author


Hmm, the postprocess function doesn't do anything when output_type is "latent"; see here.

Let me know if I am missing anything
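For illustration, the guard works roughly like this (a simplified sketch of the early-return behavior, not the actual VaeImageProcessor implementation):

```python
# Simplified sketch: latent output is returned untouched, so calling
# postprocess unconditionally is safe when output_type == "latent".
# This is an illustration, not the real VaeImageProcessor code.

def postprocess(image, output_type="pil"):
    if output_type == "latent":
        return image  # early return: no-op for latents
    # denormalize from [-1, 1] to [0, 1] for image outputs
    return [min(max(x * 0.5 + 0.5, 0.0), 1.0) for x in image]

latents = [-0.3, 0.7, 1.5]
# latent output passes through unchanged
assert postprocess(latents, output_type="latent") is latents
```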

Contributor


Even so, the code style does not match the other pipelines; please make the requested changes to keep the styling consistent.

Contributor Author


Alright, made the change. Let me know if it looks good.

Comment on lines 370 to 373
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
)
Contributor


This can be removed. Refer to other pipelines for an example.

if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

@ishan-modi ishan-modi requested a review from hlky April 10, 2025 07:19
@ishan-modi
Contributor Author

@hlky gentle ping

@yiyixuxu yiyixuxu merged commit 79ea8eb into huggingface:main Apr 21, 2025
11 of 12 checks passed
@yiyixuxu
Collaborator

thanks a lot @ishan-modi

@ishan-modi ishan-modi deleted the fixes-issue-11060 branch April 22, 2025 03:55
Successfully merging this pull request may close these issues.

prepare_image in Kandinsky pipelines doesn't support torch.Tensor