Asymmetric vqgan #3956
Conversation
Hi @cross-attention! Thanks for your PR! Could you maybe share some results with this new autoencoder? That will help us to better evaluate it. Maybe you could do:

```python
from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline

vae = AsymmetricAutoencoderKL.from_pretrained("the-ckpt-id")
pipeline = StableDiffusionInpaintPipeline.from_pretrained(ckpt_id, vae=vae).to("cuda")
...
```

That would be great!
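For readers landing here later, a fuller version of that snippet might look like the sketch below; the inpainting base checkpoint, image/mask URLs, and prompt are placeholders rather than values from this PR:

```python
from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline
from diffusers.utils import load_image

# Placeholder checkpoint id and URLs, for illustration only.
vae = AsymmetricAutoencoderKL.from_pretrained("the-ckpt-id")
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", vae=vae
).to("cuda")

image = load_image("https://example.com/image.png").resize((512, 512))
mask = load_image("https://example.com/mask.png").resize((512, 512))
result = pipeline(prompt="a photo of a cat", image=image, mask_image=mask).images[0]
result.save("inpainted.png")
```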
@sayakpaul
That looks great! Thank you! Which checkpoint did you use for the final two cases? Could you maybe provide us some code snippets?
I used the original checkpoints from https://github.com/buxiangzhiren/Asymmetric_VQGAN/

```python
import torch

from diffusers import AsymmetricAutoencoderKL

# Instantiate the config matching the checkpoint you want to convert
# (run only one of the two blocks below; as written, the second overwrites the first).

# x1.5
ckpt = torch.load("./checkpoints/larger1.5.ckpt", map_location="cpu")
vae = AsymmetricAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
    down_block_out_channels=(128, 256, 512, 512),
    layers_per_down_block=2,
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
    up_block_out_channels=(192, 384, 768, 768),
    layers_per_up_block=3,
    act_fn="silu",
    latent_channels=4,
    norm_num_groups=32,
    sample_size=256,
    scaling_factor=0.18215,
)

# x2
ckpt = torch.load("./checkpoints/larger2.ckpt", map_location="cpu")
vae = AsymmetricAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
    down_block_out_channels=(128, 256, 512, 512),
    layers_per_down_block=2,
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
    up_block_out_channels=(256, 512, 1024, 1024),
    layers_per_up_block=5,
    act_fn="silu",
    latent_channels=4,
    norm_num_groups=32,
    sample_size=256,
    scaling_factor=0.18215,
)

# Match keys: rename the original state-dict entries to the diffusers naming scheme.
enc_dict = {
    k.replace("encoder.down.", "encoder.down_blocks.")
    .replace("encoder.mid.", "encoder.mid_block.")
    .replace("encoder.norm_out.", "encoder.conv_norm_out.")
    .replace(".downsample.", ".downsamplers.0.")
    .replace(".nin_shortcut.", ".conv_shortcut.")
    .replace(".block.", ".resnets.")
    .replace(".block_1.", ".resnets.0.")
    .replace(".block_2.", ".resnets.1.")
    .replace(".attn_1.k.", ".attentions.0.to_k.")
    .replace(".attn_1.q.", ".attentions.0.to_q.")
    .replace(".attn_1.v.", ".attentions.0.to_v.")
    .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
    .replace(".attn_1.norm.", ".attentions.0.group_norm."): v
    for k, v in ckpt["state_dict"].items()
    if k.startswith("encoder.")
}
# The original attention projections are 1x1 convs; squeeze them to linear weights.
for k in enc_dict.keys():
    if (
        k.startswith("encoder.mid_block.attentions.0")
        and k.endswith("weight")
        and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
    ):
        enc_dict[k] = enc_dict[k][:, :, 0, 0]
dec_dict = {
    k.replace(".norm_out.", ".conv_norm_out.")
    .replace(".up.0.", ".up_blocks.3.")
    .replace(".up.1.", ".up_blocks.2.")
    .replace(".up.2.", ".up_blocks.1.")
    .replace(".up.3.", ".up_blocks.0.")
    .replace(".block.", ".resnets.")
    .replace("mid", "mid_block")
    .replace(".0.upsample.", ".0.upsamplers.0.")
    .replace(".1.upsample.", ".1.upsamplers.0.")
    .replace(".2.upsample.", ".2.upsamplers.0.")
    .replace(".nin_shortcut.", ".conv_shortcut.")
    .replace(".block_1.", ".resnets.0.")
    .replace(".block_2.", ".resnets.1.")
    .replace(".attn_1.k.", ".attentions.0.to_k.")
    .replace(".attn_1.q.", ".attentions.0.to_q.")
    .replace(".attn_1.v.", ".attentions.0.to_v.")
    .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
    .replace(".attn_1.norm.", ".attentions.0.group_norm."): v
    for k, v in ckpt["state_dict"].items()
    if (
        k.startswith("decoder.")
        and not k.startswith("decoder.up_layers.")
        and not k.startswith("decoder.encoder.")
    )
}
for k in dec_dict.keys():
    if (
        k.startswith("decoder.mid_block.attentions.0")
        and k.endswith("weight")
        and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
    ):
        dec_dict[k] = dec_dict[k][:, :, 0, 0]
# The condition encoder lives under decoder.up_layers / decoder.encoder in the original.
cond_enc_dict = {
    k.replace("decoder.up_layers.", "decoder.condition_encoder.up_layers.")
    .replace("decoder.encoder.", "decoder.condition_encoder."): v
    for k, v in ckpt["state_dict"].items()
    if (k.startswith("decoder.up_layers.") or k.startswith("decoder.encoder."))
}
quant_conv_dict = {k: v for k, v in ckpt["state_dict"].items() if k.startswith("quant_conv.")}
post_quant_conv_dict = {k: v for k, v in ckpt["state_dict"].items() if k.startswith("post_quant_conv.")}
vae.load_state_dict({**quant_conv_dict, **post_quant_conv_dict, **enc_dict, **dec_dict, **cond_enc_dict})
```
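As a quick sanity check after load_state_dict succeeds, a round trip like the sketch below could be run. It assumes the decode signature with the optional image/mask keywords discussed later in this thread, and the save path is a placeholder:

```python
import torch

vae.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)    # stand-in image batch
    mask = torch.ones(1, 1, 256, 256)  # dummy mask, just to exercise the code path
    z = vae.encode(x).latent_dist.sample()
    rec = vae.decode(z, image=x, mask=mask).sample
print(rec.shape)  # expected: torch.Size([1, 3, 256, 256])

vae.save_pretrained("./asymmetric-autoencoder-kl")  # placeholder directory
```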
Great, this is superb stuff! From my end, I think the PR is already in good shape. I think we need the following:
Let me know if anything is unclear here :-) More than happy to help.
@sayakpaul
Looks fantastic to me!
Final TODOs:
- https://github.com/huggingface/diffusers/pull/3956/files#r1259162306
- Add nice model cards to https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2 and https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5 so that the community is aware of these.
* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5)
* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2)
Since now we have https://huggingface.co/buxiangzhiren and we have made contact with the author, I think we can transfer the repositories. @patrickvonplaten could you please help?
Happy to transfer them once merged
Done!
@sayakpaul @patrickvonplaten
```python
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
condition_kwargs = {}
if isinstance(self.vae, AsymmetricAutoencoderKL):
    mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
```
Can we maybe compute the init_image_condition only here? Since it's not needed for the "normal" VAE?
```python
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
init_image_condition = init_image.clone()
init_image = self._encode_vae_image(init_image, generator=generator)
```
I think this is only needed when the decoder is of type AsymmetricAutoencoderKL - should we maybe add it further down, e.g. here: https://github.com/huggingface/diffusers/pull/3956/files#r1266533587
This way we can save an image encoding step.
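Putting both suggestions together, the pipeline code might end up looking roughly like this sketch (variable names are taken from the diffs above; this is an illustration, not the merged code):

```python
# Prepare the conditioning inputs only when the VAE can actually use them;
# the plain AutoencoderKL path skips the extra clone and transfers.
condition_kwargs = {}
if isinstance(self.vae, AsymmetricAutoencoderKL):
    init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
    init_image_condition = init_image.clone()
    mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
    condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
image = self.vae.decode(
    latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs
)[0]
```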
```diff
@@ -173,6 +173,56 @@ def test_output_pretrained(self):
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))


+class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
```
Nice tests!
Almost there I think! Can we also add some tests to https://github.com/huggingface/diffusers/blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py to make sure the inpainting pipeline works as expected? :-)
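A rough sketch of such a fast test, assuming the helper conventions of that file (get_dummy_components / get_dummy_inputs / torch_device) and a made-up tiny VAE config mirroring the dummy AutoencoderKL used there:

```python
# Hypothetical test method inside the existing inpaint fast-test class;
# all names and config values here are illustrative assumptions.
def test_stable_diffusion_inpaint_with_asymmetric_vae(self):
    components = self.get_dummy_components()
    components["vae"] = AsymmetricAutoencoderKL(
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
        down_block_out_channels=(32, 64),
        layers_per_down_block=1,
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
        up_block_out_channels=(32, 64),
        layers_per_up_block=1,
        latent_channels=4,
        norm_num_groups=32,
        sample_size=128,
    )
    pipe = StableDiffusionInpaintPipeline(**components).to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    inputs = self.get_dummy_inputs(torch_device)
    image = pipe(**inputs).images
    # Smoke check only: the pipeline should run end to end with the asymmetric VAE.
    assert image.shape[0] == 1 and image.shape[-1] == 3
```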
@patrickvonplaten
cool PR!
```python
for l in range(len(self.layers)):
    layer = self.layers[l]
    x = layer(x)
    out[str(tuple(x.shape))] = x
```
What's happening here? We use the string of shapes as keys to store the encoded condition and use them to match with decoder blocks?
Is it possible that two layer outputs have the same shape?
Yup, they are different.
I still don't like this - if it's possible to configure the model in a way that outputs remain the same shape between two layers, we will have a problem here.
cc @patrickvonplaten @sayakpaul let me know what you think
> if it's possible to configure the model in a way that outputs remain the same shape between two layers, we will have a problem here

Valid concern. If there's a possibility of the underlying model doing this, then, yes, let's try to rejig this part.
Agree here as well - @cross-attention could we maybe do one more round of refactoring here:
- instead of creating a dict here, we create a list of tuples that will be returned
- we also do the interpolation already here in this function instead of here: https://github.com/huggingface/diffusers/pull/3956/files#r1268048130
- then in the decoder, we make image and mask required forward args: https://github.com/huggingface/diffusers/pull/3956/files#r1268045503
- we keep the decoder code much cleaner by just popping an element from the tuple
Would something like this work? See the sketch below.
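A rough sketch of that refactor (hypothetical class and argument names; the real MaskConditionEncoder may differ):

```python
import torch
import torch.nn as nn

class MaskConditionEncoderSketch(nn.Module):
    """Illustrative only: yields the conditioning features as an ordered list of
    (feature, resized_mask) tuples instead of a shape-keyed dict."""

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, masked_image, mask):
        out = []
        x = masked_image
        for layer in self.layers:
            x = layer(x)
            # do the mask interpolation here, once, so the decoder needs no bookkeeping
            mask_ = nn.functional.interpolate(mask, size=x.shape[-2:], mode="nearest")
            out.append((x, mask_))
        return out

# Tiny usage demo with made-up conv layers:
enc = MaskConditionEncoderSketch(
    [nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.Conv2d(8, 16, 3, stride=2, padding=1)]
)
conditions = enc(torch.randn(1, 3, 64, 64), torch.ones(1, 1, 64, 64))
# The decoder would then consume the matching pair per up block, e.g.:
#   feature, mask_ = conditions.pop()  # pop order depends on layer arrangement
#   sample = sample * mask_ + feature * (1 - mask_)
```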
With optional mask and image we can also use AsymmetricAutoencoderKL for text2image.
The mask in that case would be None, right? And it seems like AsymmetricAutoencoderKL already handles this case?
If so, it might be good to add a test to clarify that (potentially in a future PR).
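Such a test could be as small as checking that decode runs with the conditioning inputs omitted; the tiny config below is illustrative only:

```python
import torch
from diffusers import AsymmetricAutoencoderKL

# Made-up minimal config (see the conversion script above for the real x1.5/x2 configs).
vae = AsymmetricAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",),
    down_block_out_channels=(32,),
    layers_per_down_block=1,
    up_block_types=("UpDecoderBlock2D",),
    up_block_out_channels=(32,),
    layers_per_up_block=1,
    latent_channels=4,
    norm_num_groups=32,
    sample_size=32,
)
z = torch.randn(1, 4, 32, 32)
out = vae.decode(z).sample  # no image/mask passed: the text2image path
assert out.shape == (1, 3, 32, 32)
```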
@patrickvonplaten @yiyixuxu what's pending in this PR? From what I see, this is the only thing that's pending: #3956 (comment). Let's try to ship this soon.
Sweet! Thanks for iterating.
```python
self.gradient_checkpointing = False

def forward(self, z, image=None, mask=None, latent_embeds=None):
```
If I understand correctly, image and mask can't really be None, no? Can we maybe force the user to pass both image and mask here -> this would make the code much easier to follow.
They can be None - in this case the decoder will work without the mask/image condition.
In this case does it also yield improvements?
The quality is almost equal in the text2image setup, according to the paper.
```python
if image is not None and mask is not None:
    sample_ = im_x[str(tuple(sample.shape))]
    mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
    sample = sample * mask_ + sample_ * (1 - mask_)
```
Can we move this code to the condition encoder forward method instead? I think it would be better placed there
@patrickvonplaten
Let's merge this PR for now as is, but if usage of Asymmetric goes up, I'd like to do a refactor here where we don't return a dict in the form <image_shape: image_out> but just a tuple of type <image_out> instead. By default we prefer the design of returning immutable tuples in diffusers over dicts. But ok for now.
Let's merge the Hub checkpoints too, @patrickvonplaten.
* added AsymmetricAutoencoderKL
* fixed copies+dummy
* added script to convert original asymmetric vqgan
* added docs
* updated docs
* fixed style
* fixes, added tests
* update doc
* fixed doc
* fixed tests
* naming
* updated code example
* updated doc
* comments fixes
* added docstring
* comments fixes
* added inpaint pipeline tests
* comment suggestion: delete method
* yet another fixes

Co-authored-by: Ruslan Vorovchenko <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Added AsymmetricAutoencoderKL for Stable Diffusion Inpainting

Added the AsymmetricAutoencoderKL model from "Designing a Better Asymmetric VQGAN for StableDiffusion" (https://arxiv.org/abs/2306.04632). Added its support in StableDiffusionInpaintPipeline.
Who can review?
@patrickvonplaten @sayakpaul