
Commit 02247d9

pacman100, younesbelkada, patrickvonplaten, sayakpaul, and BenjaminBossan authored
PEFT Integration for Text Encoder to handle multiple alphas/ranks, disable/enable adapters and support for multiple adapters (huggingface#5147)
* more fixes
* up
* up
* style
* add in setup
* oops
* more changes
* v1 rzfactor CI
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
* few todos
* protect torch import
* style
* fix fuse text encoder
* Update src/diffusers/loaders.py
  Co-authored-by: Sayak Paul <[email protected]>
* replace with `recurse_replace_peft_layers`
* keep old modules for BC
* adjustments on `adjust_lora_scale_text_encoder`
* nit
* move tests
* add conversion utils
* remove unneeded methods
* use class method instead
* oops
* use `base_version`
* fix examples
* fix CI
* fix weird error with python 3.8
* fix
* better fix
* style
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <[email protected]>
  Co-authored-by: Patrick von Platen <[email protected]>
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
* add comment
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <[email protected]>
* conv2d support for recurse remove
* added docstrings
* more docstring
* add deprecate
* revert
* try to fix merge conflicts
* peft integration features for text encoder
  1. support multiple rank/alpha values
  2. support multiple active adapters
  3. support disabling and enabling adapters
* fix bug
* fix code quality
* Apply suggestions from code review
  Co-authored-by: Younes Belkada <[email protected]>
* fix bugs
* Apply suggestions from code review
  Co-authored-by: Younes Belkada <[email protected]>
* address comments
  Co-Authored-By: Benjamin Bossan <[email protected]>
  Co-Authored-By: Patrick von Platen <[email protected]>
* fix code quality
* address comments
* address comments
* Apply suggestions from code review
* find and replace

---------

Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
1 parent 940f941 commit 02247d9
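For orientation, here is a minimal usage sketch of the text-encoder LoRA API this commit introduces. It assumes diffusers is running with the PEFT backend enabled and `peft` installed; the checkpoint and LoRA repository ids are placeholders, not part of this commit.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# Load two LoRA checkpoints under explicit adapter names; leaving `adapter_name`
# unset would auto-assign "default_0", "default_1", ... via `get_adapter_name`.
pipe.load_lora_weights("user/toy-lora", adapter_name="toy")      # hypothetical repo id
pipe.load_lora_weights("user/pixel-lora", adapter_name="pixel")  # hypothetical repo id

# Activate both adapters on the text encoder and weight their contributions.
pipe.set_adapter_for_text_encoder(["toy", "pixel"], text_encoder_weights=[1.0, 0.5])

# Temporarily switch the text-encoder LoRA layers off, then back on.
pipe.disable_lora_for_text_encoder()
pipe.enable_lora_for_text_encoder()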

File tree

4 files changed (+212, -19 lines):

src/diffusers/loaders.py
src/diffusers/models/lora.py
src/diffusers/utils/__init__.py
src/diffusers/utils/peft_utils.py

src/diffusers/loaders.py

Lines changed: 105 additions & 12 deletions
@@ -35,18 +35,23 @@
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_omegaconf_available,
     is_peft_available,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
 )
 from .utils.import_utils import BACKENDS_MAPPING


 if is_transformers_available():
-    from transformers import CLIPTextModel, CLIPTextModelWithProjection
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel

 if is_accelerate_available():
     from accelerate import init_empty_weights
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
     num_fused_loras = 0
     use_peft_backend = USE_PEFT_BACKEND

-    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
         `self.text_encoder`.
@@ -1120,6 +1127,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -1143,6 +1153,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            adapter_name=adapter_name,
             _pipeline=self,
         )

@@ -1500,6 +1511,7 @@ def load_lora_into_text_encoder(
         prefix=None,
         lora_scale=1.0,
         low_cpu_mem_usage=None,
+        adapter_name=None,
         _pipeline=None,
     ):
         """
@@ -1523,6 +1535,9 @@ def load_lora_into_text_encoder(
                 tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                 Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                 argument to `True` will raise an error.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

@@ -1584,19 +1599,22 @@ def load_lora_into_text_encoder(
         if cls.use_peft_backend:
             from peft import LoraConfig

-            lora_rank = list(rank.values())[0]
-            # By definition, the scale should be alpha divided by rank.
-            # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
-            alpha = lora_scale * lora_rank
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)

-            target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-            if patch_mlp:
-                target_modules += ["fc1", "fc2"]
+            lora_config = LoraConfig(**lora_config_kwargs)

-            # TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
-            lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
+            # derive a default adapter name if none was passed explicitly
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)

-            text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
+            # inject LoRA layers and load the state dict
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+            )
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)

         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
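To make the multi-rank/multi-alpha handling above concrete: `get_peft_kwargs` collapses the per-module rank and alpha dictionaries into the most common `r`/`lora_alpha` plus `rank_pattern`/`alpha_pattern` overrides, which PEFT's `LoraConfig` then applies per module name. A minimal sketch with made-up module names and values:

from peft import LoraConfig

# Hypothetical output of `get_peft_kwargs` for a text encoder LoRA whose layer-11
# v_proj was trained with a larger rank/alpha than the rest of the network.
lora_config_kwargs = {
    "r": 4,
    "lora_alpha": 4,
    "rank_pattern": {"text_model.encoder.layers.11.self_attn.v_proj": 8},
    "alpha_pattern": {"text_model.encoder.layers.11.self_attn.v_proj": 8},
    "target_modules": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
}
lora_config = LoraConfig(**lora_config_kwargs)  # passed on to `text_encoder.load_adapter(...)`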
@@ -2178,6 +2196,81 @@ def unfuse_text_encoder_lora(text_encoder):

         self.num_fused_loras -= 1

+    def set_adapter_for_text_encoder(
+        self,
+        adapter_names: Union[List[str], str],
+        text_encoder: Optional[PreTrainedModel] = None,
+        text_encoder_weights: List[float] = None,
+    ):
+        """
+        Sets the adapter layers for the text encoder.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to set the adapter layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+            text_encoder_weights (`List[float]`, *optional*):
+                The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        def process_weights(adapter_names, weights):
+            if weights is None:
+                weights = [1.0] * len(adapter_names)
+            elif isinstance(weights, float):
+                weights = [weights]
+
+            if len(adapter_names) != len(weights):
+                raise ValueError(
+                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+                )
+            return weights
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError(
+                "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
+            )
+        set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
+
+    def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+        """
+        Disables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(text_encoder, enabled=False)
+
+    def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+        """
+        Enables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to enable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(text_encoder, enabled=True)
+

 class FromSingleFileMixin:
     """

src/diffusers/models/lora.py

Lines changed: 2 additions & 6 deletions
@@ -19,19 +19,15 @@
 from torch import nn

 from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
-from ..utils import logging
+from ..utils import logging, scale_lora_layers


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
     if use_peft_backend:
-        from peft.tuners.lora import LoraLayer
-
-        for module in text_encoder.modules():
-            if isinstance(module, LoraLayer):
-                module.scaling[module.active_adapter] = lora_scale
+        scale_lora_layers(text_encoder, weight=lora_scale)
     else:
         for _, attn_module in text_encoder_attn_modules(text_encoder):
             if isinstance(attn_module.q_proj, PatchedLoraProjection):
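With the PEFT backend, `adjust_lora_scale_text_encoder` now simply delegates to the new `scale_lora_layers` utility. A minimal sketch, assuming `text_encoder` already has PEFT LoRA layers injected (for example `pipe.text_encoder` from the usage sketch near the top of this page):

from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.utils import scale_lora_layers

# Adjusts every PEFT LoRA layer in the text encoder via `BaseTunerLayer.scale_layer(0.7)`.
adjust_lora_scale_text_encoder(text_encoder, lora_scale=0.7, use_peft_backend=True)
# ...which is now equivalent to calling the utility directly:
# scale_lora_layers(text_encoder, weight=0.7)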

src/diffusers/utils/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -84,7 +84,14 @@
 from .loading_utils import load_image
 from .logging import get_logger
 from .outputs import BaseOutput
-from .peft_utils import recurse_remove_peft_layers
+from .peft_utils import (
+    get_adapter_name,
+    get_peft_kwargs,
+    recurse_remove_peft_layers,
+    scale_lora_layers,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
+)
 from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
 from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft

src/diffusers/utils/peft_utils.py

Lines changed: 97 additions & 0 deletions
@@ -14,6 +14,8 @@
 """
 PEFT utilities: Utilities related to peft library
 """
+import collections
+
 from .import_utils import is_torch_available


@@ -68,3 +70,98 @@ def recurse_remove_peft_layers(model):
     torch.cuda.empty_cache()

     return model
+
+
+def scale_lora_layers(model, weight):
+    """
+    Adjust the weight given to the LoRA layers of the model.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to scale.
+        weight (`float`):
+            The weight to be given to the LoRA layers.
+    """
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            module.scale_layer(weight)
+
+
+def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
+    rank_pattern = {}
+    alpha_pattern = {}
+    r = lora_alpha = list(rank_dict.values())[0]
+    if len(set(rank_dict.values())) > 1:
+        # get the rank occurring the most number of times
+        r = collections.Counter(rank_dict.values()).most_common()[0][0]
+
+        # for modules with a rank different from the most occurring rank, add them to `rank_pattern`
+        rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
+        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
+
+    if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
+        # get the alpha occurring the most number of times
+        lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
+
+        # for modules with an alpha different from the most occurring alpha, add them to `alpha_pattern`
+        alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
+        alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
+
+    # layer names without the Diffusers-specific LoRA suffix
+    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
+
+    lora_config_kwargs = {
+        "r": r,
+        "lora_alpha": lora_alpha,
+        "rank_pattern": rank_pattern,
+        "alpha_pattern": alpha_pattern,
+        "target_modules": target_modules,
+    }
+    return lora_config_kwargs
+
+
+def get_adapter_name(model):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return f"default_{len(module.r)}"
+    return "default_0"
+
+
+def set_adapter_layers(model, enabled=True):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            # recent versions of PEFT expose `enable_adapters` for this
+            if hasattr(module, "enable_adapters"):
+                module.enable_adapters(enabled=enabled)
+            else:
+                module.disable_adapters = not enabled
+
+
+def set_weights_and_activate_adapters(model, adapter_names, weights):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    # iterate over each adapter, make it active and set the corresponding scaling weight
+    for adapter_name, weight in zip(adapter_names, weights):
+        for module in model.modules():
+            if isinstance(module, BaseTunerLayer):
+                # for backward compatibility with previous PEFT versions
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                else:
+                    module.active_adapter = adapter_name
+                module.scale_layer(weight)
+
+    # set multiple active adapters
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            # for backward compatibility with previous PEFT versions
+            if hasattr(module, "set_adapter"):
+                module.set_adapter(adapter_names)
+            else:
+                module.active_adapter = adapter_names
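Taken together, these helpers are what the loader calls once PEFT LoRA layers have been injected into the text encoder. A short sketch, assuming `text_encoder` is such a module and that two adapters were loaded earlier under the auto-generated names:

from diffusers.utils import (
    get_adapter_name,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)

# Next free auto-generated name, e.g. "default_2" after two adapters have been loaded.
next_name = get_adapter_name(text_encoder)

# Make both previously loaded adapters active, giving the second one half weight.
set_weights_and_activate_adapters(text_encoder, ["default_0", "default_1"], [1.0, 0.5])

# Turn every LoRA layer off for a LoRA-free forward pass, then re-enable them.
set_adapter_layers(text_encoder, enabled=False)
set_adapter_layers(text_encoder, enabled=True)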
