import argparse
import io

import requests
import torch
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    assign_to_checkpoint,
    conv_attn_to_linear,
    create_vae_diffusers_config,
    renew_vae_attention_paths,
    renew_vae_resnet_paths,
)
from omegaconf import OmegaConf


def custom_convert_ldm_vae_checkpoint(checkpoint, config):
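    """Map a standalone LDM VAE state dict onto diffusers' AutoencoderKL key layout.

    The checkpoint is used directly as the VAE state dict (no `first_stage_model.`
    prefix handling), which suits .pt files that contain only VAE weights.
    """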
    vae_state_dict = checkpoint

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

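    # Copy each encoder down block: its resnets, plus a downsampler when one exists.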
    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

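    # Encoder mid block: LDM's block_1/block_2 become mid_block.resnets.0/1,
    # followed by the single attention layer.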
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

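    # The up-block index runs in the opposite direction between the LDM layout and
    # diffusers' up_blocks, so old `up.{block_id}` maps onto new `up_blocks.{i}`.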
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

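    # Decoder mid block: same structure as the encoder's (two resnets, one attention).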
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def vae_pt_to_vae_diffuser(
    checkpoint_path: str,
    output_path: str,
):
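    """Convert a standalone VAE .pt checkpoint into a diffusers AutoencoderKL folder.

    Assumes the .pt file stores its weights under a "state_dict" key and matches
    the SD v1 VAE architecture.
    """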
    # Only SD v1 configs are supported.
    r = requests.get(
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
    )
    io_obj = io.BytesIO(r.content)

    original_config = OmegaConf.load(io_obj)
    image_size = 512
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint["state_dict"], vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    vae.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE .pt file to convert.")
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to output the converted VAE.")

    args = parser.parse_args()

    vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
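
# Example invocation (script name and paths are placeholders):
#   python convert_vae_pt_to_diffusers.py --vae_pt_path ./vae.pt --dump_path ./vae_diffusers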