Skip to content

Commit e4b056f

Browse files
authored
[LoRA] support wan i2v loras from the world. (#11025)
* support wan i2v loras from the world. * remove copied from. * updates * add lora.
1 parent 4e3ddd5 commit e4b056f

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

docs/source/en/api/pipelines/wan.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
# Wan
1616

17+
<div class="flex flex-wrap space-x-1">
18+
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
19+
</div>
20+
1721
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
1822

1923
<!-- TODO(aryan): update abstract once paper is out -->

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,3 +1348,53 @@ def process_block(prefix, index, convert_norm):
13481348
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
13491349

13501350
return converted_state_dict
1351+
1352+
1353+
def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1354+
converted_state_dict = {}
1355+
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
1356+
1357+
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
1358+
1359+
for i in range(num_blocks):
1360+
# Self-attention
1361+
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1362+
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
1363+
f"blocks.{i}.self_attn.{o}.lora_A.weight"
1364+
)
1365+
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
1366+
f"blocks.{i}.self_attn.{o}.lora_B.weight"
1367+
)
1368+
1369+
# Cross-attention
1370+
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1371+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1372+
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1373+
)
1374+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1375+
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1376+
)
1377+
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
1378+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1379+
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1380+
)
1381+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1382+
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1383+
)
1384+
1385+
# FFN
1386+
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
1387+
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
1388+
f"blocks.{i}.{o}.lora_A.weight"
1389+
)
1390+
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
1391+
f"blocks.{i}.{o}.lora_B.weight"
1392+
)
1393+
1394+
if len(original_state_dict) > 0:
1395+
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
1396+
1397+
for key in list(converted_state_dict.keys()):
1398+
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
1399+
1400+
return converted_state_dict

src/diffusers/loaders/lora_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
_convert_kohya_flux_lora_to_diffusers,
4343
_convert_non_diffusers_lora_to_diffusers,
4444
_convert_non_diffusers_lumina2_lora_to_diffusers,
45+
_convert_non_diffusers_wan_lora_to_diffusers,
4546
_convert_xlabs_flux_lora_to_diffusers,
4647
_maybe_map_sgm_blocks_to_diffusers,
4748
)
@@ -4111,7 +4112,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
41114112

41124113
@classmethod
41134114
@validate_hf_hub_args
4114-
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
41154115
def lora_state_dict(
41164116
cls,
41174117
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4198,6 +4198,8 @@ def lora_state_dict(
41984198
user_agent=user_agent,
41994199
allow_pickle=allow_pickle,
42004200
)
4201+
if any(k.startswith("diffusion_model.") for k in state_dict):
4202+
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
42014203

42024204
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
42034205
if is_dora_scale_present:

0 commit comments

Comments
 (0)