@@ -1348,3 +1348,53 @@ def process_block(prefix, index, convert_norm):
1348
1348
converted_state_dict [f"transformer.{ key } " ] = converted_state_dict .pop (key )
1349
1349
1350
1350
return converted_state_dict
1351
+
1352
+
1353
+ def _convert_non_diffusers_wan_lora_to_diffusers (state_dict ):
1354
+ converted_state_dict = {}
1355
+ original_state_dict = {k [len ("diffusion_model." ) :]: v for k , v in state_dict .items ()}
1356
+
1357
+ num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in original_state_dict })
1358
+
1359
+ for i in range (num_blocks ):
1360
+ # Self-attention
1361
+ for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1362
+ converted_state_dict [f"blocks.{ i } .attn1.{ c } .lora_A.weight" ] = original_state_dict .pop (
1363
+ f"blocks.{ i } .self_attn.{ o } .lora_A.weight"
1364
+ )
1365
+ converted_state_dict [f"blocks.{ i } .attn1.{ c } .lora_B.weight" ] = original_state_dict .pop (
1366
+ f"blocks.{ i } .self_attn.{ o } .lora_B.weight"
1367
+ )
1368
+
1369
+ # Cross-attention
1370
+ for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1371
+ converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = original_state_dict .pop (
1372
+ f"blocks.{ i } .cross_attn.{ o } .lora_A.weight"
1373
+ )
1374
+ converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = original_state_dict .pop (
1375
+ f"blocks.{ i } .cross_attn.{ o } .lora_B.weight"
1376
+ )
1377
+ for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
1378
+ converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = original_state_dict .pop (
1379
+ f"blocks.{ i } .cross_attn.{ o } .lora_A.weight"
1380
+ )
1381
+ converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = original_state_dict .pop (
1382
+ f"blocks.{ i } .cross_attn.{ o } .lora_B.weight"
1383
+ )
1384
+
1385
+ # FFN
1386
+ for o , c in zip (["ffn.0" , "ffn.2" ], ["net.0.proj" , "net.2" ]):
1387
+ converted_state_dict [f"blocks.{ i } .ffn.{ c } .lora_A.weight" ] = original_state_dict .pop (
1388
+ f"blocks.{ i } .{ o } .lora_A.weight"
1389
+ )
1390
+ converted_state_dict [f"blocks.{ i } .ffn.{ c } .lora_B.weight" ] = original_state_dict .pop (
1391
+ f"blocks.{ i } .{ o } .lora_B.weight"
1392
+ )
1393
+
1394
+ if len (original_state_dict ) > 0 :
1395
+ raise ValueError (f"`state_dict` should be empty at this point but has { original_state_dict .keys ()= } " )
1396
+
1397
+ for key in list (converted_state_dict .keys ()):
1398
+ converted_state_dict [f"transformer.{ key } " ] = converted_state_dict .pop (key )
1399
+
1400
+ return converted_state_dict
0 commit comments