|
79 | 79 | "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", |
80 | 80 | "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", |
81 | 81 | "animatediff_rgb": "controlnet_cond_embedding.weight", |
82 | | - "flux": "double_blocks.0.img_attn.norm.key_norm.scale", |
| 82 | + "flux": [ |
| 83 | + "double_blocks.0.img_attn.norm.key_norm.scale", |
| 84 | + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", |
| 85 | + ], |
83 | 86 | } |
84 | 87 |
|
85 | 88 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = { |
|
258 | 261 | "timestep_spacing": "leading", |
259 | 262 | } |
260 | 263 |
|
261 | | -LDM_VAE_KEY = "first_stage_model." |
| 264 | +LDM_VAE_KEYS = ["first_stage_model.", "vae."] |
262 | 265 | LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 |
263 | 266 | PLAYGROUND_VAE_SCALING_FACTOR = 0.5 |
264 | 267 | LDM_UNET_KEY = "model.diffusion_model." |
|
267 | 270 | "cond_stage_model.transformer.", |
268 | 271 | "conditioner.embedders.0.transformer.", |
269 | 272 | ] |
270 | | -OPEN_CLIP_PREFIX = "conditioner.embedders.0.model." |
271 | 273 | LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 |
272 | 274 | SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"] |
273 | 275 |
|
@@ -523,8 +525,10 @@ def infer_diffusers_model_type(checkpoint): |
523 | 525 | else: |
524 | 526 | model_type = "animatediff_v3" |
525 | 527 |
|
526 | | - elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint: |
527 | | - if "guidance_in.in_layer.bias" in checkpoint: |
| 528 | + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): |
| 529 | + if any( |
| 530 | + g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] |
| 531 | + ): |
528 | 532 | model_type = "flux-dev" |
529 | 533 | else: |
530 | 534 | model_type = "flux-schnell" |
@@ -1183,7 +1187,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config): |
1183 | 1187 | # remove whichever LDM_VAE_KEYS prefix is present (if any) from the ldm checkpoint keys so that it is easier to map them to diffusers keys |
1184 | 1188 | vae_state_dict = {} |
1185 | 1189 | keys = list(checkpoint.keys()) |
1186 | | - vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else "" |
| 1190 | + vae_key = "" |
| 1191 | + for ldm_vae_key in LDM_VAE_KEYS: |
| 1192 | + if any(k.startswith(ldm_vae_key) for k in keys): |
| 1193 | + vae_key = ldm_vae_key |
| 1194 | + |
1187 | 1195 | for key in keys: |
1188 | 1196 | if key.startswith(vae_key): |
1189 | 1197 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) |
@@ -1896,6 +1904,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): |
1896 | 1904 |
|
1897 | 1905 | def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): |
1898 | 1906 | converted_state_dict = {} |
| 1907 | + keys = list(checkpoint.keys()) |
| 1908 | + for k in keys: |
| 1909 | + if "model.diffusion_model." in k: |
| 1910 | + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) |
1899 | 1911 |
|
1900 | 1912 | num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 |
1901 | 1913 | num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 |
|
0 commit comments