
Commit 07f1fbb

Authored by cross-attention (Ruslan Vorovchenko), Sayak Paul, and Patrick von Platen
Asymmetric vqgan (huggingface#3956)
* added AsymmetricAutoencoderKL
* fixed copies + dummy
* added script to convert original asymmetric vqgan
* added docs
* updated docs
* fixed style
* fixes, added tests
* updated doc
* fixed doc
* fixed tests
* naming (Co-authored-by: Sayak Paul <[email protected]>)
* naming (Co-authored-by: Sayak Paul <[email protected]>)
* updated code example
* updated doc
* comment fixes
* added docstring (Co-authored-by: Patrick von Platen <[email protected]>)
* comment fixes
* added inpaint pipeline tests
* comment suggestion: delete method
* yet more fixes

---------

Co-authored-by: Ruslan Vorovchenko <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2551b73 commit 07f1fbb

11 files changed (+1119, −20 lines)


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -166,6 +166,8 @@
       title: VQModel
     - local: api/models/autoencoderkl
       title: AutoencoderKL
+    - local: api/models/asymmetricautoencoderkl
+      title: AsymmetricAutoencoderKL
     - local: api/models/transformer2d
       title: Transformer2D
     - local: api/models/transformer_temporal
```
docs/source/en/api/models/asymmetricautoencoderkl.md

Lines changed: 55 additions & 0 deletions
# AsymmetricAutoencoderKL

An improved, larger variational autoencoder (VAE) model with KL loss for the inpainting task: [Designing a Better Asymmetric VQGAN for StableDiffusion](https://arxiv.org/abs/2306.04632) by Zixin Zhu, Xuelu Feng, Dongdong Chen, Jianmin Bao, Le Wang, Yinpeng Chen, Lu Yuan, and Gang Hua.

The abstract from the paper is:

*StableDiffusion is a revolutionary text-to-image generator that is causing a stir in the world of image generation and editing. Unlike traditional methods that learn a diffusion model in pixel space, StableDiffusion learns a diffusion model in the latent space via a VQGAN, ensuring both efficiency and quality. It not only supports image generation tasks, but also enables image editing for real images, such as image inpainting and local editing. However, we have observed that the vanilla VQGAN used in StableDiffusion leads to significant information loss, causing distortion artifacts even in non-edited image regions. To this end, we propose a new asymmetric VQGAN with two simple designs. Firstly, in addition to the input from the encoder, the decoder contains a conditional branch that incorporates information from task-specific priors, such as the unmasked image region in inpainting. Secondly, the decoder is much heavier than the encoder, allowing for more detailed recovery while only slightly increasing the total inference cost. The training cost of our asymmetric VQGAN is cheap, and we only need to retrain a new asymmetric decoder while keeping the vanilla VQGAN encoder and StableDiffusion unchanged. Our asymmetric VQGAN can be widely used in StableDiffusion-based inpainting and local editing methods. Extensive experiments demonstrate that it can significantly improve the inpainting and editing performance, while maintaining the original text-to-image capability. The code is available at https://github.com/buxiangzhiren/Asymmetric_VQGAN*

Evaluation results can be found in Section 4.1 of the original paper.
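The conditional branch described in the abstract is exposed on the model itself: in this commit, `AsymmetricAutoencoderKL.decode` also accepts the original image and the inpainting mask. Below is a minimal sketch with random stand-in tensors; the decode keyword arguments and the mask convention (1 for pixels to repaint) are assumed from the inpainting pipeline changes in this PR.

```python
import torch

from diffusers import AsymmetricAutoencoderKL

vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
vae.eval()

image = torch.randn(1, 3, 256, 256)  # stand-in for a real image batch in [-1, 1]
mask = torch.zeros(1, 1, 256, 256)   # assumed convention: 1 = region to repaint
mask[:, :, 96:160, 96:160] = 1.0     # hypothetical square hole

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()
    # the decoder's conditional branch incorporates the unmasked image region
    reconstruction = vae.decode(latents, image=image, mask=mask).sample
```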
## Available checkpoints

* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5)
* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2)
## Example Usage

```python
from io import BytesIO

from PIL import Image
import requests
from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline


def download_image(url: str) -> Image.Image:
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")


prompt = "a photo of a person"
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"

image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))

pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe.vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
pipe.to("cuda")

image = pipe(prompt=prompt, image=image, mask_image=mask_image).images[0]
image.save("image.jpeg")
```
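Note that only the VAE is swapped here; the UNet, text encoder, and scheduler of the inpainting pipeline are untouched. This mirrors the paper's design: only a new asymmetric decoder is retrained, while the vanilla VQGAN encoder and the rest of StableDiffusion stay unchanged.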
## AsymmetricAutoencoderKL

[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL

## AutoencoderKLOutput

[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.vae.DecoderOutput
scripts/convert_asymmetric_vqgan_to_diffusers.py

Lines changed: 184 additions & 0 deletions
```python
import argparse
import time
from pathlib import Path
from typing import Any, Dict, Literal

import torch

from diffusers import AsymmetricAutoencoderKL


# x1.5 variant: the decoder's up blocks are 1.5x as wide as the encoder's
# down blocks (192/128, 384/256, 768/512) and one layer deeper
ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ],
    "down_block_out_channels": [128, 256, 512, 512],
    "layers_per_down_block": 2,
    "up_block_types": [
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
    "up_block_out_channels": [192, 384, 768, 768],
    "layers_per_up_block": 3,
    "act_fn": "silu",
    "latent_channels": 4,
    "norm_num_groups": 32,
    "sample_size": 256,
    "scaling_factor": 0.18215,
}

# x2 variant: the decoder's up blocks are twice as wide and deeper still
ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ],
    "down_block_out_channels": [128, 256, 512, 512],
    "layers_per_down_block": 2,
    "up_block_types": [
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
    "up_block_out_channels": [256, 512, 1024, 1024],
    "layers_per_up_block": 5,
    "act_fn": "silu",
    "latent_channels": 4,
    "norm_num_groups": 32,
    "sample_size": 256,
    "scaling_factor": 0.18215,
}


def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        if k.startswith("encoder."):
            # rename original VQGAN encoder keys to diffusers' module names
            converted_state_dict[
                k.replace("encoder.down.", "encoder.down_blocks.")
                .replace("encoder.mid.", "encoder.mid_block.")
                .replace("encoder.norm_out.", "encoder.conv_norm_out.")
                .replace(".downsample.", ".downsamplers.0.")
                .replace(".nin_shortcut.", ".conv_shortcut.")
                .replace(".block.", ".resnets.")
                .replace(".block_1.", ".resnets.0.")
                .replace(".block_2.", ".resnets.1.")
                .replace(".attn_1.k.", ".attentions.0.to_k.")
                .replace(".attn_1.q.", ".attentions.0.to_q.")
                .replace(".attn_1.v.", ".attentions.0.to_v.")
                .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
                .replace(".attn_1.norm.", ".attentions.0.group_norm.")
            ] = v
        elif k.startswith("decoder.") and "up_layers" not in k:
            # rename decoder keys; note the up-block order is reversed in diffusers
            converted_state_dict[
                k.replace("decoder.encoder.", "decoder.condition_encoder.")
                .replace(".norm_out.", ".conv_norm_out.")
                .replace(".up.0.", ".up_blocks.3.")
                .replace(".up.1.", ".up_blocks.2.")
                .replace(".up.2.", ".up_blocks.1.")
                .replace(".up.3.", ".up_blocks.0.")
                .replace(".block.", ".resnets.")
                .replace("mid", "mid_block")
                .replace(".0.upsample.", ".0.upsamplers.0.")
                .replace(".1.upsample.", ".1.upsamplers.0.")
                .replace(".2.upsample.", ".2.upsamplers.0.")
                .replace(".nin_shortcut.", ".conv_shortcut.")
                .replace(".block_1.", ".resnets.0.")
                .replace(".block_2.", ".resnets.1.")
                .replace(".attn_1.k.", ".attentions.0.to_k.")
                .replace(".attn_1.q.", ".attentions.0.to_q.")
                .replace(".attn_1.v.", ".attentions.0.to_v.")
                .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
                .replace(".attn_1.norm.", ".attentions.0.group_norm.")
            ] = v
        elif k.startswith("quant_conv."):
            converted_state_dict[k] = v
        elif k.startswith("post_quant_conv."):
            converted_state_dict[k] = v
        else:
            print(f" skipping key `{k}`")
    # fix weights shape: attention projections are 1x1 convs in the original
    # checkpoint but linear layers in diffusers
    for k, v in converted_state_dict.items():
        if (
            (k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
            and k.endswith("weight")
            and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
        ):
            converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]

    return converted_state_dict


def get_asymmetric_autoencoder_kl_from_original_checkpoint(
    scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
) -> AsymmetricAutoencoderKL:
    print("Loading original state_dict")
    original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
    original_state_dict = original_state_dict["state_dict"]
    print("Converting state_dict")
    converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
    kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
    print("Initializing AsymmetricAutoencoderKL model")
    asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
    print("Loading weight from converted state_dict")
    asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
    asymmetric_autoencoder_kl.eval()
    print("AsymmetricAutoencoderKL successfully initialized")
    return asymmetric_autoencoder_kl


if __name__ == "__main__":
    start = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--scale",
        default=None,
        type=str,
        required=True,
        help="Asymmetric VQGAN scale: `1.5` or `2`",
    )
    parser.add_argument(
        "--original_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the original Asymmetric VQGAN checkpoint",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        required=True,
        help="Path to save pretrained AsymmetricAutoencoderKL model",
    )
    parser.add_argument(
        "--map_location",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading the checkpoint",
    )
    args = parser.parse_args()

    assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` or `2`"
    assert Path(args.original_checkpoint_path).is_file()

    asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
        scale=args.scale,
        original_checkpoint_path=args.original_checkpoint_path,
        map_location=torch.device(args.map_location),
    )
    print("Saving pretrained AsymmetricAutoencoderKL")
    asymmetric_autoencoder_kl.save_pretrained(args.output_path)
    print(f"Done in {time.time() - start:.2f} seconds")
```

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -36,6 +36,7 @@
     from .utils.dummy_pt_objects import *  # noqa F403
 else:
     from .models import (
+        AsymmetricAutoencoderKL,
         AutoencoderKL,
         ControlNetModel,
         ModelMixin,
```

src/diffusers/models/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -17,6 +17,7 @@
 
 if is_torch_available():
     from .adapter import MultiAdapter, T2IAdapter
+    from .autoencoder_asym_kl import AsymmetricAutoencoderKL
     from .autoencoder_kl import AutoencoderKL
     from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
```
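Together, these two export additions make the new model importable both from the top-level package and from the models subpackage. A quick sanity check (a sketch, assuming diffusers is installed at this commit):

```python
from diffusers import AsymmetricAutoencoderKL
from diffusers.models import AsymmetricAutoencoderKL as ModelsAsymmetricAutoencoderKL

# both import paths resolve to the same class object
assert AsymmetricAutoencoderKL is ModelsAsymmetricAutoencoderKL
```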
