diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 80929a1c8a0b..df3aa6212f78 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1825,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+    has_time_projection_weight = any(
+        k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+    )

-    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
-    if diff_keys:
-        for diff_k in diff_keys:
-            param = original_state_dict[diff_k]
-            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
-            # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
-            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
-            # is okay to ignore because they do not affect the model output in a significant manner.
-            threshold = 1.6e-2
-            absdiff = param.abs().max() - param.abs().min()
-            all_zero = torch.all(param == 0).item()
-            all_absdiff_lower_than_threshold = absdiff < threshold
-            if all_zero or all_absdiff_lower_than_threshold:
-                logger.debug(
-                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
-                )
-                original_state_dict.pop(diff_k)
+    for key in list(original_state_dict.keys()):
+        if key.endswith((".diff", ".diff_b")) and "norm" in key:
+            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+            # in future if needed and they are not zeroed.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+        if "time_projection" in key and not has_time_projection_weight:
+            # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+            # our lora config adds the time proj lora layers, but we don't have the weights for them.
+            # CausVid lora has the weight keys and the bias keys.
+            original_state_dict.pop(key)

     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py
index 740c00f941ed..a7eb74080499 100644
--- a/tests/lora/test_lora_layers_wanvace.py
+++ b/tests/lora/test_lora_layers_wanvace.py
@@ -28,6 +28,7 @@
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import (
     floats_tensor,
+    is_flaky,
     require_peft_backend,
     require_peft_version_greater,
     skip_mps,
@@ -215,3 +216,7 @@ def test_lora_exclude_modules_wanvace(self):
             np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
             "Lora outputs should match.",
         )
+
+    @is_flaky
+    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+        super().test_simple_inference_with_text_denoiser_lora_and_scale()
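
For illustration only (not part of the patch), below is a minimal standalone sketch of the key-pruning logic the first hunk adds, run against a hypothetical toy Wan LoRA state dict. The key names in the dict are invented for the example; only the filtering conditions mirror the patch.

# sketch of the new pruning behavior, assuming illustrative key names
import torch

original_state_dict = {
    "blocks.0.norm3.diff": torch.zeros(16),                    # norm diff key -> dropped
    "time_projection.1.diff_b": torch.randn(16),               # no time_projection weight key -> dropped
    "blocks.0.self_attn.q.lora_A.weight": torch.randn(4, 16),  # regular LoRA key -> kept
}

# True only if the LoRA actually ships time_projection weights (CausVid-style dicts).
has_time_projection_weight = any(
    k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
)

for key in list(original_state_dict.keys()):
    if key.endswith((".diff", ".diff_b")) and "norm" in key:
        # Norm-layer diff keys are effectively zero, so they are dropped rather than converted.
        original_state_dict.pop(key)
    elif "time_projection" in key and not has_time_projection_weight:
        # With only the bias diffs present (AccVideo-style dicts), keeping these keys would
        # create time_projection LoRA layers that have no weights, so they are dropped too.
        original_state_dict.pop(key)

print(sorted(original_state_dict))  # ['blocks.0.self_attn.q.lora_A.weight']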