@@ -1495,10 +1495,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        return_alphas: bool = False,
         **kwargs,
     ):
         r"""
@@ -1583,7 +1583,26 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )
 
-        return state_dict
+        # For state dicts like
+        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
+        keys = list(state_dict.keys())
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                alpha_value = state_dict.get(k)
+                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                ):
+                    network_alphas[k] = state_dict.pop(k)
+                else:
+                    raise ValueError(
+                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
+                    )
+
+        if return_alphas:
+            return state_dict, network_alphas
+        else:
+            return state_dict
 
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1617,14 +1636,17 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
+        )
 
         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
         self.load_lora_into_transformer(
             state_dict,
+            network_alphas=network_alphas,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
@@ -1634,7 +1656,7 @@ def load_lora_weights(
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_state_dict,
-                network_alphas=None,
+                network_alphas=network_alphas,
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
@@ -1643,8 +1665,7 @@ def load_lora_weights(
             )
 
     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -1653,6 +1674,10 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             transformer (`SD3Transformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
@@ -1684,7 +1709,12 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
             if "lora_B" in key:
                 rank[key] = val.shape[1]
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+        if network_alphas is not None and len(network_alphas) >= 1:
+            prefix = cls.transformer_name
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                 raise ValueError(
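
For reference, a minimal usage sketch of the API after this patch. The checkpoint repo comes from the comment in the diff; the stand-alone prefix-filtering snippet only mirrors the added logic and is illustrative, not part of the change.

# Minimal sketch, assuming a diffusers build with this patch and Hub access.
from diffusers import FluxPipeline

# `lora_state_dict` now pops kohya-style "alpha" entries out of the checkpoint
# and, with `return_alphas=True`, returns them alongside the state dict.
state_dict, network_alphas = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA", return_alphas=True
)

# `load_lora_into_transformer` keeps only the alphas that belong to the
# transformer before calling `get_peft_kwargs`; the same filtering, stand-alone
# (cls.transformer_name resolves to "transformer" for Flux pipelines):
prefix = "transformer"
transformer_alphas = {
    k.replace(f"{prefix}.", ""): v
    for k, v in network_alphas.items()
    if k.split(".")[0] == prefix
}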