
Commit 0703ce8

ishan-modi and DN6 authored
[Single File] Add single file loading for SANA Transformer (#10947)
* added support for from_single_file
* added diffusers mapping script
* added testcase
* bug fix
* updated tests
* corrected code quality
* corrected code quality

Co-authored-by: Dhruv Nair <[email protected]>
1 parent f5edaa7 commit 0703ce8

File tree

4 files changed, +183 -2 lines changed
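
What the commit enables, as a minimal usage sketch: loading the SANA transformer directly from the original single-file checkpoint. The checkpoint URL is the one used in the tests added below; the bfloat16 dtype is an illustrative choice, not part of this diff.

import torch

from diffusers import SanaTransformer2DModel

# Original (non-diffusers) checkpoint, as referenced by the new single-file tests.
ckpt_path = (
    "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
)
transformer = SanaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)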

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,7 @@
     convert_ltx_vae_checkpoint_to_diffusers,
     convert_lumina2_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
+    convert_sana_transformer_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     convert_wan_transformer_to_diffusers,
@@ -119,6 +120,10 @@
         "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "SanaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "WanTransformer3DModel": {
         "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
         "default_subfolder": "transformer",

src/diffusers/loaders/single_file_utils.py

Lines changed: 115 additions & 0 deletions
@@ -117,6 +117,12 @@
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+    "sana": [
+        "blocks.0.cross_attn.q_linear.weight",
+        "blocks.0.cross_attn.q_linear.bias",
+        "blocks.0.cross_attn.kv_linear.weight",
+        "blocks.0.cross_attn.kv_linear.bias",
+    ],
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
 }
@@ -178,6 +184,7 @@
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
     "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
@@ -669,6 +676,9 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
         model_type = "lumina2"
 
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+        model_type = "sana"
+
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
         if "model.diffusion_model.patch_embedding.weight" in checkpoint:
             target_key = "model.diffusion_model.patch_embedding.weight"
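
The new branch detects SANA checkpoints by the presence of their cross-attention projection keys. A standalone sketch of that rule applied to a dummy state dict; the dummy dict and its single key are illustrative only (real checkpoints map key names to tensors).

SANA_KEYS = [
    "blocks.0.cross_attn.q_linear.weight",
    "blocks.0.cross_attn.q_linear.bias",
    "blocks.0.cross_attn.kv_linear.weight",
    "blocks.0.cross_attn.kv_linear.bias",
]

# Stand-in for a loaded checkpoint; values would normally be tensors.
dummy_checkpoint = {"blocks.0.cross_attn.q_linear.weight": None}

model_type = "sana" if any(key in dummy_checkpoint for key in SANA_KEYS) else "unknown"
assert model_type == "sana"
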
@@ -2897,6 +2907,111 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
     return converted_state_dict
 
 
+def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401
+
+    # Positional and patch embeddings.
+    checkpoint.pop("pos_embed")
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Timestep embeddings.
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "t_embedder.mlp.0.weight"
+    )
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "t_embedder.mlp.2.weight"
+    )
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+    converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+    # Caption Projection.
+    checkpoint.pop("y_embedder.y_embedding")
+    converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+    converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+    converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+    converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+    converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+    for i in range(num_layers):
+        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
+            f"blocks.{i}.scale_shift_table"
+        )
+
+        # Self-Attention
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
+
+        # Output Projections
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
+            f"blocks.{i}.attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
+            f"blocks.{i}.attn.proj.bias"
+        )
+
+        # Cross-Attention
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.q_linear.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.q_linear.bias"
+        )
+
+        linear_sample_k, linear_sample_v = torch.chunk(
+            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
+        )
+        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
+            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
+
+        # Output Projections
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.proj.bias"
+        )
+
+        # MLP
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.inverted_conv.conv.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
+            f"blocks.{i}.mlp.inverted_conv.conv.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.depth_conv.conv.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
+            f"blocks.{i}.mlp.depth_conv.conv.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.point_conv.conv.weight"
+        )
+
+    # Final layer
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+    return converted_state_dict
+
+
 def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
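
Per block, the converter's main work is splitting the original fused projections: the self-attention qkv weight is chunked into three equal row blocks and the cross-attention kv weight into two. A toy demonstration of that split with random tensors; inner_dim is arbitrary here and stands in for the transformer width.

import torch

# Fused self-attention projection: rows are [q; k; v].
inner_dim = 8
qkv_weight = torch.randn(3 * inner_dim, inner_dim)
q, k, v = torch.chunk(qkv_weight, 3, dim=0)  # each chunk is (inner_dim, inner_dim)

# Fused cross-attention projection: rows are [k; v].
kv_weight = torch.randn(2 * inner_dim, inner_dim)
k_cross, v_cross = torch.chunk(kv_weight, 2, dim=0)

assert q.shape == k_cross.shape == (inner_dim, inner_dim)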

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import (
     Attention,
@@ -195,7 +195,7 @@ def forward(
         return hidden_states
 
 
-class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
 
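
With FromOriginalModelMixin added to the class, a transformer loaded from the original .pth file can be swapped into a diffusers pipeline. A hedged sketch assuming SanaPipeline and the repo id used elsewhere in this commit; the dtype and the component-override pattern are illustrative, not part of the diff.

import torch

from diffusers import SanaPipeline, SanaTransformer2DModel

ckpt_path = (
    "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
)
transformer = SanaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.float16)

# Use the converted transformer together with the rest of the diffusers-format pipeline.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    transformer=transformer,
    torch_dtype=torch.float16,
)
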
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+    SanaTransformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    require_torch_accelerator,
+    torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
+    model_class = SanaTransformer2DModel
+    ckpt_path = (
+        "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+    )
+    alternate_keys_ckpt_paths = [
+        "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+    ]
+
+    repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_single_file_components(self):
+        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+        model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+        for param_name, param_value in model_single_file.config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+            assert (
+                model.config[param_name] == param_value
+            ), f"{param_name} differs between single file loading and pretrained loading"
+
+    def test_checkpoint_loading(self):
+        for ckpt_path in self.alternate_keys_ckpt_paths:
+            torch.cuda.empty_cache()
+            model = self.model_class.from_single_file(ckpt_path)
+
+            del model
+            gc.collect()
+            torch.cuda.empty_cache()
