Asymmetric vqgan #3956

Merged: 26 commits into huggingface:main on Jul 20, 2023

Conversation

@cross-attention (Contributor) commented Jul 5, 2023

Added AsymmetricAutoencoderKL for Stable Diffusion inpainting:

  • Added the AsymmetricAutoencoderKL model from "Designing a Better Asymmetric VQGAN for StableDiffusion" (https://arxiv.org/abs/2306.04632).
  • Added support for it in StableDiffusionInpaintPipeline.


Who can review?

@patrickvonplaten @sayakpaul

@HuggingFaceDocBuilderDev commented Jul 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@sayakpaul (Member)

Hi @cross-attention!

Thanks for your PR! Could you maybe share some results with this new autoencoder? That would help us evaluate it better.

Maybe if you could do:

from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline

vae = AsymmetricAutoencoderKL.from_pretrained("the-ckpt-id")
pipeline = StableDiffusionInpaintPipeline.from_pretrained(ckpt_id, vae=vae).to("cuda")

...

That would be great!
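(A complete call with the swapped-in VAE might look like the following sketch; the image and mask URLs are placeholders, and the VAE repo id is the one shared later in this thread.)

import torch
from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline
from diffusers.utils import load_image

# placeholder inputs: swap in your own image and mask
image = load_image("https://example.com/photo.png")
mask = load_image("https://example.com/photo_mask.png")  # white = region to repaint

vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", vae=vae
).to("cuda")

result = pipeline(
    prompt="a closeup photo of a male person in a black t-shirt on a solid yellow background",
    image=image,
    mask_image=mask,
).images[0]
result.save("inpainted.png")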

@cross-attention (Contributor, Author)

@sayakpaul
Sure! Here are the results with runwayml/stable-diffusion-inpainting (prompt: "a closeup photo of a male person in a black t-shirt on a solid yellow background").
There are 4 groups of 9 examples each, with the following VAE setups:

  1. default
  2. stabilityai/sd-vae-ft-mse
  3. AsymmetricAutoencoderKL x1.5 scale
  4. AsymmetricAutoencoderKL x2 scale

https://i.imgur.com/arZi2vT.jpg

@sayakpaul (Member)

That looks great, thank you! Which checkpoints did you use for the final two cases, AsymmetricAutoencoderKL x1.5 scale and AsymmetricAutoencoderKL x2 scale?

Could you maybe provide us with some code snippets?

@cross-attention (Contributor, Author)

I used the original checkpoints from https://github.com/buxiangzhiren/Asymmetric_VQGAN/.
To match the keys, I used the following code:

import torch
from diffusers import AsymmetricAutoencoderKL

# build the model matching the checkpoint you want to convert
# (run either the x1.5 block or the x2 block, not both)

# x1.5
ckpt = torch.load("./checkpoints/larger1.5.ckpt", map_location="cpu")
vae = AsymmetricAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
    down_block_out_channels=(128, 256, 512, 512),
    layers_per_down_block=2,
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
    up_block_out_channels=(192, 384, 768, 768),
    layers_per_up_block=3,
    act_fn="silu",
    latent_channels=4,
    norm_num_groups=32,
    sample_size=256,
    scaling_factor=0.18215,
)

# x2
ckpt = torch.load("./checkpoints/larger2.ckpt", map_location="cpu")
vae = AsymmetricAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
    down_block_out_channels=(128, 256, 512, 512),
    layers_per_down_block=2,
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
    up_block_out_channels=(256, 512, 1024, 1024),
    layers_per_up_block=5,
    act_fn="silu",
    latent_channels=4,
    norm_num_groups=32,
    sample_size=256,
    scaling_factor=0.18215,
)

# match keys: remap the original state dict names to the diffusers naming scheme
enc_dict = {
    k
    .replace("encoder.down.", "encoder.down_blocks.")
    .replace("encoder.mid.", "encoder.mid_block.")
    .replace("encoder.norm_out.", "encoder.conv_norm_out.")
    .replace(".downsample.", ".downsamplers.0.")
    .replace(".nin_shortcut.", ".conv_shortcut.")
    .replace(".block.", ".resnets.")
    .replace(".block_1.", ".resnets.0.")
    .replace(".block_2.", ".resnets.1.")
    .replace(".attn_1.k.", ".attentions.0.to_k.")
    .replace(".attn_1.q.", ".attentions.0.to_q.")
    .replace(".attn_1.v.", ".attentions.0.to_v.")
    .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
    .replace(".attn_1.norm.", ".attentions.0.group_norm.")
    :
    v
    for k, v in ckpt["state_dict"].items() if k.startswith("encoder.")
}
# the original attention q/k/v/out projections are stored as 1x1 convs;
# squeeze them to linear weights for diffusers
for k in enc_dict.keys():
    if (
        k.startswith("encoder.mid_block.attentions.0") and
        k.endswith("weight") and
        ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
    ):
        enc_dict[k] = enc_dict[k][:, :, 0, 0]
dec_dict = {
    k
    .replace(".norm_out.", ".conv_norm_out.")
    .replace(".up.0.", ".up_blocks.3.")
    .replace(".up.1.", ".up_blocks.2.")
    .replace(".up.2.", ".up_blocks.1.")
    .replace(".up.3.", ".up_blocks.0.")
    .replace(".block.", ".resnets.")
    .replace("mid", "mid_block")
    .replace(".0.upsample.", ".0.upsamplers.0.")
    .replace(".1.upsample.", ".1.upsamplers.0.")
    .replace(".2.upsample.", ".2.upsamplers.0.")
    .replace(".nin_shortcut.", ".conv_shortcut.")
    .replace(".block_1.", ".resnets.0.")
    .replace(".block_2.", ".resnets.1.")
    .replace(".attn_1.k.", ".attentions.0.to_k.")
    .replace(".attn_1.q.", ".attentions.0.to_q.")
    .replace(".attn_1.v.", ".attentions.0.to_v.")
    .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
    .replace(".attn_1.norm.", ".attentions.0.group_norm.")
    :
    v
    for k, v in ckpt["state_dict"].items() if (
        k.startswith("decoder.") and
        not k.startswith("decoder.up_layers.") and
        not k.startswith("decoder.encoder.")
    )
}
# same 1x1-conv -> linear squeeze for the decoder attention weights
for k in dec_dict.keys():
    if (
        k.startswith("decoder.mid_block.attentions.0") and
        k.endswith("weight") and
        ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
    ):
        dec_dict[k] = dec_dict[k][:, :, 0, 0]
cond_enc_dict = {
    k
    .replace("decoder.up_layers.", "decoder.condition_encoder.up_layers.")
    .replace("decoder.encoder.", "decoder.condition_encoder.")
    :
    v
    for k, v in ckpt["state_dict"].items() if (
        k.startswith("decoder.up_layers.") or
        k.startswith("decoder.encoder.")
    )
}
quant_conv_dict = {k: v for k, v in ckpt["state_dict"].items() if k.startswith("quant_conv.")}
post_quant_conv_dict = {k: v for k, v in ckpt["state_dict"].items() if k.startswith("post_quant_conv.")}

vae.load_state_dict({**quant_conv_dict, **post_quant_conv_dict, **enc_dict, **dec_dict, **cond_enc_dict})
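After loading, the converted model can be serialized in the diffusers layout; a minimal sketch (the output path and repo id are placeholders):

# save in the diffusers format so it can later be loaded with
# AsymmetricAutoencoderKL.from_pretrained(...)
vae.save_pretrained("./asymmetric-autoencoder-kl-x-1-5")
# optionally push to the Hugging Face Hub (requires a logged-in token):
# vae.push_to_hub("your-username/asymmetric-autoencoder-kl-x-1-5")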

@sayakpaul (Member)

Great, this is superb stuff! From my end, the PR is already in good shape. I think we need the following:

  • A conversion script to convert the original checkpoints to the diffusers format -- looks like you already have it. You can check how we structure these scripts here: https://github.com/huggingface/diffusers/tree/main/scripts.
  • Host the converted checkpoints on the Hugging Face Hub. https://hf.co/buxiangzhiren doesn't exist yet, so we'd likely host them under your HF profile first; once https://hf.co/buxiangzhiren is available, we can transfer the converted checkpoints there with a nice model card that includes the usage example.
  • Add tests.
  • Document the usage.

Let me know if anything is unclear here :-) More than happy to help.

@cross-attention (Contributor, Author)

@sayakpaul
Updates are ready!
Models:
x1.5 https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5
x2 https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2

@sayakpaul (Member) left a review.

@cross-attention (Contributor, Author)

@sayakpaul

  • added model_card
  • fixed tests
  • updated doc

Comment on lines +13 to +14
* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5)
* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2)
@sayakpaul (Member):
Since now we have https://huggingface.co/buxiangzhiren and we have made contact with the author, I think we can transfer the repositories. @patrickvonplaten could you please help?

@patrickvonplaten (Contributor):

Happy to transfer them once merged

@patrickvonplaten (Contributor):
Done!

@cross-attention (Contributor, Author)

@sayakpaul @patrickvonplaten
added fixes for the review comments

image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
condition_kwargs = {}
if isinstance(self.vae, AsymmetricAutoencoderKL):
    mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
@patrickvonplaten (Contributor):

Can we maybe compute the init_image_condition only here, since it's not needed for the "normal" VAE?

Comment on lines 985 to 987
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
init_image_condition = init_image.clone()
init_image = self._encode_vae_image(init_image, generator=generator)
@patrickvonplaten (Contributor):

Suggested change (remove these lines):
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
init_image_condition = init_image.clone()
init_image = self._encode_vae_image(init_image, generator=generator)

I think this is only needed when the decoder is of type AsymmetricAutoencoderKL - should we maybe add it further down, e.g. here: https://github.com/huggingface/diffusers/pull/3956/files#r1266533587

This way we can save an image encoding step; see the sketch below.
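(A minimal sketch of the restructuring being proposed, assuming the asymmetric decoder accepts the condition via keyword arguments as in this PR; variable names follow the quoted snippets.)

condition_kwargs = {}
if isinstance(self.vae, AsymmetricAutoencoderKL):
    # prepare the image/mask condition only for the asymmetric decoder,
    # saving the extra work on the standard VAE path
    init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
    mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
    condition_kwargs = {"image": init_image, "mask": mask_condition}
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]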

@@ -173,6 +173,56 @@ def test_output_pretrained(self):
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))


class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
@patrickvonplaten (Contributor):

Nice tests!

@patrickvonplaten (Contributor) left a review:

Almost there I think! Can we also add some tests to https://github.com/huggingface/diffusers/blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py to make sure the inpainting pipeline works as expected? :-)

@cross-attention (Contributor, Author)

@patrickvonplaten
done!

@yiyixuxu (Collaborator) left a review:

cool PR!

for l in range(len(self.layers)):
    layer = self.layers[l]
    x = layer(x)
    out[str(tuple(x.shape))] = x
@yiyixuxu (Collaborator):

What's happening here? We use the string of each output's shape as the key to store the encoded condition, and then use it to match with the decoder blocks?

Is it possible that two layer outputs have the same shape?

@cross-attention (Contributor, Author):

Yup, they are different.

@yiyixuxu (Collaborator):

I still don't like this - if it's possible to configure the model in a way that the outputs keep the same shape between two layers, we will have a problem here.

cc @patrickvonplaten @sayakpaul let me know what you think

@sayakpaul (Member):

if it's possible to configure the model in a way that output remain same shape between two layers we will have a problem here

Valid concern. If there's a possibility of the underlying model doing this, then yes, let's try to rejig this part.

@patrickvonplaten (Contributor):

Agree here as well - @cross-attention, could we maybe do one more round of refactoring here:

Would something like this work?
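(A hedged illustration of the direction under discussion - keying the condition features by block index instead of by shape string. This is a sketch with hypothetical names, not the actual suggestion from the review.)

import torch.nn as nn

class MaskConditionEncoderSketch(nn.Module):
    # hypothetical variant: return features as an ordered tuple instead of a
    # dict keyed by str(shape), so two same-shaped outputs cannot collide;
    # decoder block i would then read features[i]
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        features = []
        for layer in self.layers:
            x = layer(x)
            features.append(x)
        return tuple(features)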

@cross-attention (Contributor, Author):

With an optional mask and image, we can also use AsymmetricAutoencoderKL for text2image.

@sayakpaul (Member):

The mask in that case would be None, right? And it seems like AsymmetricAutoencoderKL already handles this case?

If so, it might be good to add a test to clarify that (potentially in a future PR).
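(A tiny sketch of such a check, assuming decode runs with no condition by default; the repo id is the x1.5 checkpoint shared above, and the latent size implies a 512x512 output.)

import torch
from diffusers import AsymmetricAutoencoderKL

vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
latents = torch.randn(1, 4, 64, 64)

with torch.no_grad():
    # no image/mask condition -> the text2image case discussed above
    unconditioned = vae.decode(latents).sample
print(unconditioned.shape)  # expected: torch.Size([1, 3, 512, 512])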

@sayakpaul (Member)

@patrickvonplaten @yiyixuxu what's pending in this PR?

From what I see, this is the only thing that's pending: #3956 (comment).

Let's try to ship this soon.

@sayakpaul (Member) left a review:

Sweet! Thanks for iterating.


        self.gradient_checkpointing = False

    def forward(self, z, image=None, mask=None, latent_embeds=None):
@patrickvonplaten (Contributor):

If I understand correctly, image and mask can't really be None, no? Can we maybe force the user to pass both image and mask here -> this would make the code much easier to follow.

@cross-attention (Contributor, Author):

They can be None; in that case the decoder works without the mask/image condition.

@patrickvonplaten (Contributor):

In this case does it also yield improvements?

@cross-attention (Contributor, Author):

The quality is almost equal in the text2image setup, according to the paper.

Comment on lines +467 to +470
if image is not None and mask is not None:
    sample_ = im_x[str(tuple(sample.shape))]
    mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
    sample = sample * mask_ + sample_ * (1 - mask_)
@patrickvonplaten (Contributor):

Suggested change (remove these lines):

if image is not None and mask is not None:
    sample_ = im_x[str(tuple(sample.shape))]
    mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
    sample = sample * mask_ + sample_ * (1 - mask_)

Can we move this code to the condition encoder's forward method instead? I think it would be better placed there.
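(For reference, the blend factored into a standalone helper, as a minimal sketch; the function name and placement are hypothetical, the semantics follow the snippet above.)

import torch
import torch.nn as nn

def blend_with_condition(sample: torch.Tensor, mask: torch.Tensor, sample_: torch.Tensor) -> torch.Tensor:
    # keep decoder features where mask == 1 and substitute the
    # condition-encoder features where mask == 0
    mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
    return sample * mask_ + sample_ * (1 - mask_)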

@cross-attention (Contributor, Author)

@patrickvonplaten
Let's clearly define all remaining steps, keeping the last comments in mind, please.

@patrickvonplaten (Contributor)

Let's merge this PR for now as is, but if usage of Asymmetric goes up, I'd like to do a refactor here where we don't return a dict of the form <image_shape: image_out> but just a tuple of <image_out> instead.

By default we prefer the design of returning non-mutable tuples in diffusers over dicts. But OK for now.

@patrickvonplaten patrickvonplaten merged commit 07f1fbb into huggingface:main Jul 20, 2023
@sayakpaul (Member)

Let's merge the Hub checkpoints too, @patrickvonplaten.

@patrickvonplaten (Contributor) commented Jul 20, 2023

orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
* added AsymmetricAutoencoderKL

* fixed copies+dummy

* added script to convert original asymmetric vqgan

* added docs

* updated docs

* fixed style

* fixes, added tests

* update doc

* fixed doc

* fixed tests

* naming

Co-authored-by: Sayak Paul <[email protected]>

* naming

Co-authored-by: Sayak Paul <[email protected]>

* udpated code example

* updated doc

* comments fixes

* added docstring

Co-authored-by: Patrick von Platen <[email protected]>

* comments fixes

* added inpaint pipeline tests

* comment suggestion: delete method

* yet another fixes

---------

Co-authored-by: Ruslan Vorovchenko <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
orpatashnik pushed the same commit to orpatashnik/diffusers twice more on Aug 1, 2023.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request on Dec 25, 2023.
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024.