Commit 715a7da

add sd3 conversion script (huggingface#8702)
add conversion script
1 parent 14d224d commit 715a7da

1 file changed: 248 additions, 0 deletions

import argparse
from contextlib import nullcontext

import safetensors.torch
import torch
from accelerate import init_empty_weights

from diffusers import AutoencoderKL, SD3Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available


# Use init_empty_weights only when accelerate is actually installed (the check must be called).
CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="fp16")

args = parser.parse_args()
dtype = torch.float16 if args.dtype == "fp16" else torch.float32


def load_original_checkpoint(ckpt_path):
    original_state_dict = safetensors.torch.load_file(ckpt_path)
    keys = list(original_state_dict.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            original_state_dict[k.replace("model.diffusion_model.", "")] = original_state_dict.pop(k)

    return original_state_dict


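# For example, load_original_checkpoint re-keys an entry such as
#   "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight"
# as
#   "joint_blocks.0.x_block.attn.qkv.weight"
# so the conversion below can address the transformer weights without the LDM prefix.

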
# In the original SD3 implementation of AdaLayerNormContinuous, the output of the linear projection is split
# into (shift, scale); in diffusers it is split into (scale, shift). Here we swap the two halves of the linear
# projection weights so that the diffusers implementation can be used.
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


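# A quick illustration of the swap: for a 1-D weight laid out as [shift, scale],
#   swap_scale_shift(torch.tensor([0., 1., 2., 3.]), dim=0) -> tensor([2., 3., 0., 1.]),
# i.e. the (shift, scale) halves come back as (scale, shift).
# (The `dim` argument is accepted but the chunking is always along dim=0.)

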
def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
    converted_state_dict = {}

    # Positional and patch embeddings.
    converted_state_dict["pos_embed.pos_embed"] = original_state_dict.pop("pos_embed")
    converted_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")

    # Timestep embeddings.
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "t_embedder.mlp.0.bias"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "t_embedder.mlp.2.bias"
    )

    # Context projections.
    converted_state_dict["context_embedder.weight"] = original_state_dict.pop("context_embedder.weight")
    converted_state_dict["context_embedder.bias"] = original_state_dict.pop("context_embedder.bias")

    # Pooled context projection.
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
        "y_embedder.mlp.0.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
        "y_embedder.mlp.0.bias"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
        "y_embedder.mlp.2.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
        "y_embedder.mlp.2.bias"
    )

    # Transformer blocks 🎸.
    for i in range(num_layers):
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
        )

        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])

        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])

        # output projections.
        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.attn.proj.bias"
        )
        if not (i == num_layers - 1):
            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.attn.proj.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.attn.proj.bias"
            )

        # norms.
        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
        )
        if not (i == num_layers - 1):
            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
            )
        else:
            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
                original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
                dim=caption_projection_dim,
            )
            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
                original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
                dim=caption_projection_dim,
            )

        # ffs.
        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = original_state_dict.pop(
            f"joint_blocks.{i}.x_block.mlp.fc2.bias"
        )
        if not (i == num_layers - 1):
            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.mlp.fc1.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.mlp.fc1.bias"
            )
            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.mlp.fc2.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = original_state_dict.pop(
                f"joint_blocks.{i}.context_block.mlp.fc2.bias"
            )

    # Final blocks.
    converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
    )
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
    )

    return converted_state_dict


def is_vae_in_checkpoint(original_state_dict):
    return ("first_stage_model.decoder.conv_in.weight" in original_state_dict) and (
        "first_stage_model.encoder.conv_in.weight" in original_state_dict
    )


def main(args):
    original_ckpt = load_original_checkpoint(args.checkpoint_path)
    # Infer the number of joint blocks from the highest "joint_blocks.<i>." index in the checkpoint.
    num_layers = max(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k) + 1
    caption_projection_dim = 1536

    converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
        original_ckpt, num_layers, caption_projection_dim
    )

    with CTX():
        transformer = SD3Transformer2DModel(
            sample_size=64,
            patch_size=2,
            in_channels=16,
            joint_attention_dim=4096,
            num_layers=num_layers,
            caption_projection_dim=caption_projection_dim,
            num_attention_heads=24,
            pos_embed_max_size=192,
        )
    if is_accelerate_available():
        load_model_dict_into_meta(transformer, converted_transformer_state_dict)
    else:
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)

    print("Saving SD3 Transformer in Diffusers format.")
    transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")

    if is_vae_in_checkpoint(original_ckpt):
        with CTX():
            vae = AutoencoderKL.from_config(
                "stabilityai/stable-diffusion-xl-base-1.0",
                subfolder="vae",
                latent_channels=16,
                use_post_quant_conv=False,
                use_quant_conv=False,
                scaling_factor=1.5305,
                shift_factor=0.0609,
            )
        converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
        if is_accelerate_available():
            load_model_dict_into_meta(vae, converted_vae_state_dict)
        else:
            vae.load_state_dict(converted_vae_state_dict, strict=True)

        print("Saving SD3 Autoencoder in Diffusers format.")
        vae.to(dtype).save_pretrained(f"{args.output_path}/vae")


if __name__ == "__main__":
    main(args)
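
A hedged usage sketch, not part of the committed file: assuming the script is saved as scripts/convert_sd3_to_diffusers.py (the file name is not shown on this page) and given an SD3 checkpoint, it is driven by the three argparse flags defined above, and the converted components can then be loaded back with the standard diffusers from_pretrained API. The checkpoint and output paths below are placeholders.

# Example invocation (script path, checkpoint path, and output path are placeholders):
#   python scripts/convert_sd3_to_diffusers.py \
#       --checkpoint_path sd3_medium.safetensors \
#       --output_path sd3-diffusers \
#       --dtype fp16

import torch

from diffusers import AutoencoderKL, SD3Transformer2DModel

# Load the converted subfolders written by the script above.
transformer = SD3Transformer2DModel.from_pretrained("sd3-diffusers/transformer", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("sd3-diffusers/vae", torch_dtype=torch.float16)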
