Commit d87ce2c

Authored by CyberVyhlky and github-actions[bot]

Fix missing **kwargs in lora_pipeline.py (#11011)

* Update lora_pipeline.py
* Apply style fixes
* fix-copies

---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 36d0553 commit d87ce2c
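
The bug being fixed is the classic dropped-**kwargs pattern: each mixin override accepts **kwargs but never passes them on, so any extra option the caller supplies silently disappears before it reaches the base class. A minimal, self-contained sketch of the pattern (toy class and option names, not the actual diffusers hierarchy):

from typing import List


class LoraBase:
    def unfuse_lora(self, components: List[str] = [], **kwargs):
        # Only the base class actually consumes the extra options.
        print(f"unfusing {components}, options={kwargs}")


class BrokenMixin(LoraBase):
    def unfuse_lora(self, components: List[str] = ["unet"], **kwargs):
        # Bug: **kwargs is accepted but never forwarded, so callers'
        # extra options vanish here.
        super().unfuse_lora(components=components)


class FixedMixin(LoraBase):
    def unfuse_lora(self, components: List[str] = ["unet"], **kwargs):
        # Fix (what this commit applies to every mixin): forward **kwargs.
        super().unfuse_lora(components=components, **kwargs)


BrokenMixin().unfuse_lora(some_option=True)  # options={} -- silently dropped
FixedMixin().unfuse_lora(some_option=True)   # options={'some_option': True}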

1 file changed (+72, -24 lines)

src/diffusers/loaders/lora_pipeline.py
@@ -452,7 +452,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -473,7 +477,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -892,7 +896,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
@@ -913,7 +921,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class SD3LoraLoaderMixin(LoraBaseMixin):
@@ -1291,7 +1299,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
@@ -1313,7 +1325,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class FluxLoraLoaderMixin(LoraBaseMixin):
@@ -1829,7 +1841,11 @@ def fuse_lora(
             )

         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
@@ -1850,7 +1866,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
             transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)

-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)

     # We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
     def unload_lora_weights(self, reset_to_overwritten_params=False):
@@ -2549,7 +2565,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -2567,7 +2587,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class Mochi1LoraLoaderMixin(LoraBaseMixin):
@@ -2853,7 +2873,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -2872,7 +2896,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3158,7 +3182,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3177,7 +3205,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class SanaLoraLoaderMixin(LoraBaseMixin):
@@ -3463,7 +3491,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3482,7 +3514,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3771,7 +3803,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3790,7 +3826,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class Lumina2LoraLoaderMixin(LoraBaseMixin):
@@ -4080,7 +4116,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
@@ -4099,7 +4139,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class WanLoraLoaderMixin(LoraBaseMixin):
@@ -4386,7 +4426,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4405,7 +4449,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class CogView4LoraLoaderMixin(LoraBaseMixin):
@@ -4691,7 +4735,11 @@ def fuse_lora(
         ```
         """
         super().fuse_lora(
-            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
        )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4710,7 +4758,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
         """
-        super().unfuse_lora(components=components)
+        super().unfuse_lora(components=components, **kwargs)


 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
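
With the forwarding in place, extra keyword arguments given to any of these overrides now reach LoraBaseMixin. A hypothetical smoke test for that hop, mocking the base method so no model weights are needed (it assumes LoraBaseMixin lives in diffusers.loaders.lora_base; extra_option is a made-up kwarg used only to observe the forwarding):

from unittest import mock

from diffusers.loaders.lora_pipeline import StableDiffusionLoraLoaderMixin


def test_unfuse_lora_forwards_kwargs():
    mixin = StableDiffusionLoraLoaderMixin()
    # Patch the base-class method so the call never touches real weights.
    with mock.patch("diffusers.loaders.lora_base.LoraBaseMixin.unfuse_lora") as base:
        mixin.unfuse_lora(components=["unet"], extra_option=True)
    # Before this commit, extra_option would have been dropped on the way up.
    base.assert_called_once_with(components=["unet"], extra_option=True)


test_unfuse_lora_forwards_kwargs()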
