117 | 117 |     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
118 | 118 |     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
119 | 119 |     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
    | 120 | +     "sana": [
    | 121 | +         "blocks.0.cross_attn.q_linear.weight",
    | 122 | +         "blocks.0.cross_attn.q_linear.bias",
    | 123 | +         "blocks.0.cross_attn.kv_linear.weight",
    | 124 | +         "blocks.0.cross_attn.kv_linear.bias",
    | 125 | +     ],
120 | 126 |     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
121 | 127 |     "wan_vae": "decoder.middle.0.residual.0.gamma",
122 | 128 | }

178 | 184 |     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
179 | 185 |     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
180 | 186 |     "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
    | 187 | +     "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
181 | 188 |     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
182 | 189 |     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
183 | 190 |     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},

@@ -669,6 +676,9 @@ def infer_diffusers_model_type(checkpoint):
669 | 676 |     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
670 | 677 |         model_type = "lumina2"
671 | 678 |
    | 679 | +     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
    | 680 | +         model_type = "sana"
    | 681 | +
672 | 682 |     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
673 | 683 |         if "model.diffusion_model.patch_embedding.weight" in checkpoint:
674 | 684 |             target_key = "model.diffusion_model.patch_embedding.weight"

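Before the conversion hunk below, a standalone sketch of its core remapping: the original Sana checkpoint stores the cross-attention K/V projection as one fused kv_linear tensor (and self-attention as a fused qkv tensor), which the converter splits into the separate per-projection weights expected by diffusers. The hidden size here is an assumption for illustration:

import torch

hidden_dim = 2240  # illustrative; the actual width depends on the Sana variant

# Fused K/V projection as stored in the original checkpoint ("blocks.{i}.cross_attn.kv_linear.weight").
kv_linear_weight = torch.randn(2 * hidden_dim, hidden_dim)

# The converter chunks it along dim=0 into the diffusers-style to_k / to_v weights.
to_k_weight, to_v_weight = torch.chunk(kv_linear_weight, 2, dim=0)
assert to_k_weight.shape == to_v_weight.shape == (hidden_dim, hidden_dim)
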
@@ -2897,6 +2907,111 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
2897 | 2907 |     return converted_state_dict
2898 | 2908 |
2899 | 2909 |
     | 2910 | + def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
     | 2911 | +     converted_state_dict = {}
     | 2912 | +     keys = list(checkpoint.keys())
     | 2913 | +     for k in keys:
     | 2914 | +         if "model.diffusion_model." in k:
     | 2915 | +             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
     | 2916 | +
     | 2917 | +     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401
     | 2918 | +
     | 2919 | +     # Positional and patch embeddings.
     | 2920 | +     checkpoint.pop("pos_embed")
     | 2921 | +     converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
     | 2922 | +     converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
     | 2923 | +
     | 2924 | +     # Timestep embeddings.
     | 2925 | +     converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
     | 2926 | +         "t_embedder.mlp.0.weight"
     | 2927 | +     )
     | 2928 | +     converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
     | 2929 | +     converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
     | 2930 | +         "t_embedder.mlp.2.weight"
     | 2931 | +     )
     | 2932 | +     converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
     | 2933 | +     converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
     | 2934 | +     converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
     | 2935 | +
     | 2936 | +     # Caption Projection.
     | 2937 | +     checkpoint.pop("y_embedder.y_embedding")
     | 2938 | +     converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
     | 2939 | +     converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
     | 2940 | +     converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
     | 2941 | +     converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
     | 2942 | +     converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
     | 2943 | +
     | 2944 | +     for i in range(num_layers):
     | 2945 | +         converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
     | 2946 | +             f"blocks.{i}.scale_shift_table"
     | 2947 | +         )
     | 2948 | +
     | 2949 | +         # Self-Attention
     | 2950 | +         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
     | 2951 | +         converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
     | 2952 | +         converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
     | 2953 | +         converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
     | 2954 | +
     | 2955 | +         # Output Projections
     | 2956 | +         converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
     | 2957 | +             f"blocks.{i}.attn.proj.weight"
     | 2958 | +         )
     | 2959 | +         converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
     | 2960 | +             f"blocks.{i}.attn.proj.bias"
     | 2961 | +         )
     | 2962 | +
     | 2963 | +         # Cross-Attention
     | 2964 | +         converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
     | 2965 | +             f"blocks.{i}.cross_attn.q_linear.weight"
     | 2966 | +         )
     | 2967 | +         converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
     | 2968 | +             f"blocks.{i}.cross_attn.q_linear.bias"
     | 2969 | +         )
     | 2970 | +
     | 2971 | +         linear_sample_k, linear_sample_v = torch.chunk(
     | 2972 | +             checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
     | 2973 | +         )
     | 2974 | +         linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
     | 2975 | +             checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
     | 2976 | +         )
     | 2977 | +         converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
     | 2978 | +         converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
     | 2979 | +         converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
     | 2980 | +         converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
     | 2981 | +
     | 2982 | +         # Output Projections
     | 2983 | +         converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
     | 2984 | +             f"blocks.{i}.cross_attn.proj.weight"
     | 2985 | +         )
     | 2986 | +         converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
     | 2987 | +             f"blocks.{i}.cross_attn.proj.bias"
     | 2988 | +         )
     | 2989 | +
     | 2990 | +         # MLP
     | 2991 | +         converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
     | 2992 | +             f"blocks.{i}.mlp.inverted_conv.conv.weight"
     | 2993 | +         )
     | 2994 | +         converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
     | 2995 | +             f"blocks.{i}.mlp.inverted_conv.conv.bias"
     | 2996 | +         )
     | 2997 | +         converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
     | 2998 | +             f"blocks.{i}.mlp.depth_conv.conv.weight"
     | 2999 | +         )
     | 3000 | +         converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
     | 3001 | +             f"blocks.{i}.mlp.depth_conv.conv.bias"
     | 3002 | +         )
     | 3003 | +         converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
     | 3004 | +             f"blocks.{i}.mlp.point_conv.conv.weight"
     | 3005 | +         )
     | 3006 | +
     | 3007 | +     # Final layer
     | 3008 | +     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
     | 3009 | +     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
     | 3010 | +     converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
     | 3011 | +
     | 3012 | +     return converted_state_dict
     | 3013 | +
     | 3014 | +
2900 | 3015 | def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
2901 | 3016 |     converted_state_dict = {}
2902 | 3017 |
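With these pieces in place, loading an original Sana checkpoint through the single-file path might look like the sketch below; the checkpoint filename is a placeholder, and SanaTransformer2DModel.from_single_file is assumed to be the entry point that routes through convert_sana_transformer_to_diffusers:

import torch
from diffusers import SanaTransformer2DModel

# Placeholder path to an original (non-diffusers) Sana checkpoint.
transformer = SanaTransformer2DModel.from_single_file(
    "path/to/original_sana_checkpoint.safetensors",
    torch_dtype=torch.bfloat16,
)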