
Commit 717a956

chavinlo and patrickvonplaten authored
Create convert_vae_pt_to_diffusers.py (huggingface#2215)
* Create convert_vae_pt_to_diffusers.py: just a simple script to convert VAE .pt files to the diffusers format. Tested with: https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/VAEs/orangemix.vae.pt
* Update convert_vae_pt_to_diffusers.py: forgot to add the function call
* make style

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: chavinlo <[email protected]>
1 parent d43972a commit 717a956

File tree

4 files changed: +158 −9 lines changed

examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py

Lines changed: 3 additions & 3 deletions
@@ -752,9 +752,9 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
-                        index_no_updates
-                    ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+                        orig_embeds_params[index_no_updates]
+                    )

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:

examples/textual_inversion/textual_inversion.py

Lines changed: 3 additions & 3 deletions
@@ -749,9 +749,9 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
-                        index_no_updates
-                    ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+                        orig_embeds_params[index_no_updates]
+                    )

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
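Both textual_inversion hunks are `make style` reformatting only: the masked assignment that restores every embedding row except the newly learned placeholder token is behaviorally unchanged. A minimal, self-contained sketch of that masking pattern (the tiny vocabulary size and placeholder id below are made up purely for illustration):

import torch

vocab_size, embed_dim, placeholder_token_id = 4, 2, 3  # hypothetical sizes
embeddings = torch.nn.Embedding(vocab_size, embed_dim)
orig_embeds_params = embeddings.weight.detach().clone()

# Pretend an optimizer step nudged every row, including the frozen ones.
with torch.no_grad():
    embeddings.weight += 0.1

# Boolean mask: True for every token except the placeholder being trained.
index_no_updates = torch.arange(vocab_size) != placeholder_token_id
with torch.no_grad():
    embeddings.weight[index_no_updates] = orig_embeds_params[index_no_updates]

# Only the placeholder row keeps the update; all other rows are restored.
assert torch.allclose(embeddings.weight[index_no_updates], orig_embeds_params[index_no_updates])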
scripts/convert_vae_pt_to_diffusers.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
import argparse
import io

import torch

import requests
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    assign_to_checkpoint,
    conv_attn_to_linear,
    create_vae_diffusers_config,
    renew_vae_attention_paths,
    renew_vae_resnet_paths,
)
from omegaconf import OmegaConf


def custom_convert_ldm_vae_checkpoint(checkpoint, config):
    vae_state_dict = checkpoint

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def vae_pt_to_vae_diffuser(
    checkpoint_path: str,
    output_path: str,
):
    # Only support V1
    r = requests.get(
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
    )
    io_obj = io.BytesIO(r.content)

    original_config = OmegaConf.load(io_obj)
    image_size = 512
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint["state_dict"], vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    vae.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to save the converted diffusers VAE to.")

    args = parser.parse_args()

    vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
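With the script in place, conversion is a one-shot CLI call, and the result loads like any diffusers VAE. A hedged usage sketch; the local paths and the runwayml/stable-diffusion-v1-5 model id are placeholders for illustration, not part of the commit:

# Convert first, e.g.:
#   python scripts/convert_vae_pt_to_diffusers.py --vae_pt_path orangemix.vae.pt --dump_path ./orangemix-vae
from diffusers import AutoencoderKL, StableDiffusionPipeline

vae = AutoencoderKL.from_pretrained("./orangemix-vae")  # the --dump_path directory from above
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae)

Note the script hard-codes the v1-inference.yaml download ("Only support V1"), so pairing the converted VAE with a v1 pipeline is the expected case.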
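A round trip through the converted autoencoder is a cheap sanity check on the remapped weights. A sketch under the same assumed output path; in diffusers, encode() returns an output holding a latent distribution and decode() returns a DecoderOutput:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("./orangemix-vae").eval()

# Dummy batch in [-1, 1], the input range the SD v1 VAE expects.
x = torch.rand(1, 3, 512, 512) * 2 - 1

with torch.no_grad():
    latents = vae.encode(x).latent_dist.sample()  # -> (1, 4, 64, 64): 4 channels at 1/8 resolution
    recon = vae.decode(latents).sample  # -> (1, 3, 512, 512)

print(latents.shape, recon.shape)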

src/diffusers/utils/hub_utils.py

Lines changed: 3 additions & 3 deletions
@@ -112,9 +112,9 @@ def create_model_card(args, model_name):
         learning_rate=args.learning_rate,
         train_batch_size=args.train_batch_size,
         eval_batch_size=args.eval_batch_size,
-        gradient_accumulation_steps=args.gradient_accumulation_steps
-        if hasattr(args, "gradient_accumulation_steps")
-        else None,
+        gradient_accumulation_steps=(
+            args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
+        ),
         adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
         adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
         adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
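The hub_utils change is likewise pure formatting: the style pass wraps the long conditional expression in parentheses. For what it's worth, the same guard can be spelled with getattr and a default; a tiny sketch of that equivalent idiom (not what the commit does, just an aside):

from types import SimpleNamespace

args = SimpleNamespace(learning_rate=1e-4)  # no gradient_accumulation_steps attribute here

# Equivalent to: args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
gradient_accumulation_steps = getattr(args, "gradient_accumulation_steps", None)
print(gradient_accumulation_steps)  # prints: None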
