
Commit 585f941

[Core] introduce PeftAdapterMixin module. (huggingface#6416)
* introduce integrations module.
* remove duplicate methods.
* better imports.
* move to loaders.py
* remove peftadaptermixin from modelmixin.
* add: peftadaptermixin selectively.
* add: entry to _toctree
* Empty-Commit
1 parent 86a2676 commit 585f941

9 files changed: +228 −158 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -212,6 +212,8 @@
       title: Textual Inversion
     - local: api/loaders/unet
       title: UNet
+    - local: api/loaders/peft
+      title: PEFT
     title: Loaders
   - sections:
     - local: api/models/overview
```

docs/source/en/api/loaders/peft.md

Lines changed: 25 additions & 0 deletions
```diff
@@ -0,0 +1,25 @@
+<!--Copyright 2023 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# PEFT
+
+Diffusers supports working with adapters (such as [LoRA](../../using-diffusers/loading_adapters)) via the [`peft` library](https://huggingface.co/docs/peft/index). We provide a `PeftAdapterMixin` class to handle this for modeling classes in Diffusers (such as [`UNet2DConditionModel`]).
+
+<Tip>
+
+Refer to [this doc](../../tutorials/using_peft_for_inference.md) to get an overview of how to work with `peft` in Diffusers for inference.
+
+</Tip>
+
+## PeftAdapterMixin
+
+[[autodoc]] loaders.peft.PeftAdapterMixin
```
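To make the new doc page concrete, here is a minimal sketch of attaching a fresh LoRA adapter to a `UNet2DConditionModel` through this mixin (the checkpoint ID, adapter name, and LoRA hyperparameters below are illustrative placeholders; `peft` must be installed):

```python
# Minimal sketch: attach a new LoRA adapter to a UNet for training.
# Assumes `pip install peft` and network/local access to the checkpoint.
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # attention projections
)

# `add_adapter` is provided by PeftAdapterMixin; it injects the adapter
# layers into the model and activates them.
unet.add_adapter(lora_config, adapter_name="my_lora")
```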

src/diffusers/loaders/__init__.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING
 
 from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
-from ..utils.import_utils import is_torch_available, is_transformers_available
+from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available
 
 
 def text_encoder_lora_state_dict(text_encoder):
@@ -64,6 +64,8 @@ def text_encoder_attn_modules(text_encoder):
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = ["IPAdapterMixin"]
 
+    _import_structure["peft"] = ["PeftAdapterMixin"]
+
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
@@ -76,6 +78,8 @@ def text_encoder_attn_modules(text_encoder):
         from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
         from .single_file import FromSingleFileMixin
         from .textual_inversion import TextualInversionLoaderMixin
+
+        from .peft import PeftAdapterMixin
 else:
     import sys
 
```
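With the `_import_structure` entry and the eager-import branch in place, the mixin resolves through the `_LazyModule` machinery. A quick sanity check (assuming `diffusers` at this commit and `peft` are installed):

```python
# Sanity check of the lazy/eager import paths wired up above.
from diffusers import UNet2DConditionModel
from diffusers.loaders import PeftAdapterMixin

# Per the commit message, the mixin is added selectively to models such as
# UNet2DConditionModel, so this should print True at this commit.
print(issubclass(UNet2DConditionModel, PeftAdapterMixin))
```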
src/diffusers/loaders/peft.py

Lines changed: 188 additions & 0 deletions
```diff
@@ -0,0 +1,188 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Union
+
+from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available
+
+
+class PeftAdapterMixin:
+    """
+    A class containing all functions for loading and using adapter weights that are supported in the PEFT library.
+    For more details about adapters and injecting them into a transformer-based model, check out the PEFT
+    documentation: https://huggingface.co/docs/peft/index.
+
+    With this mixin, if the correct PEFT version is installed, it is possible to:
+
+    - Attach new adapters to the model.
+    - Attach multiple adapters and iteratively activate/deactivate them.
+    - Activate/deactivate all adapters on the model.
+    - Get a list of the active adapters.
+    """
+
+    _hf_peft_config_loaded = False
+
+    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
+        r"""
+        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is
+        assigned to the adapter, following the convention of the PEFT library.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        [documentation](https://huggingface.co/docs/peft).
+
+        Args:
+            adapter_config ([`~peft.PeftConfig`]):
+                The configuration of the adapter to add; supported adapters are non-prefix-tuning and adaption-prompt
+                methods.
+            adapter_name (`str`, *optional*, defaults to `"default"`):
+                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not is_peft_available():
+            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
+
+        from peft import PeftConfig, inject_adapter_in_model
+
+        if not self._hf_peft_config_loaded:
+            self._hf_peft_config_loaded = True
+        elif adapter_name in self.peft_config:
+            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+        if not isinstance(adapter_config, PeftConfig):
+            raise ValueError(
+                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
+            )
+
+        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
+        # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
+        adapter_config.base_model_name_or_path = None
+        inject_adapter_in_model(adapter_config, self, adapter_name)
+        self.set_adapter(adapter_name)
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets a specific adapter by forcing the model to only use that adapter and disabling the other adapters.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        official documentation: https://huggingface.co/docs/peft
+
+        Args:
+            adapter_name (`Union[str, List[str]]`):
+                The name of the adapter to set, or a list of adapter names in the multi-adapter case.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        if isinstance(adapter_name, str):
+            adapter_name = [adapter_name]
+
+        missing = set(adapter_name) - set(self.peft_config)
+        if len(missing) > 0:
+            raise ValueError(
+                f"The following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
+                f" The currently loaded adapters are: {list(self.peft_config.keys())}"
+            )
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        _adapters_have_been_set = False
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                # Previous versions of PEFT do not support multi-adapter inference
+                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
+                    raise ValueError(
+                        "You are trying to set multiple adapters, but your PEFT version does not support multi-adapter inference. Please upgrade to the latest version of PEFT:"
+                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
+                    )
+                else:
+                    module.active_adapter = adapter_name
+                _adapters_have_been_set = True
+
+        if not _adapters_have_been_set:
+            raise ValueError(
+                "Did not succeed in setting the adapter. Please make sure you are using a model that supports adapters."
+            )
+
+    def disable_adapters(self) -> None:
+        r"""
+        Disables all adapters attached to the model and falls back to inference with the base model only.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=False)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = True
+
+    def enable_adapters(self) -> None:
+        """
+        Enables the adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the
+        list of adapters to enable.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=True)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = False
+
+    def active_adapters(self) -> List[str]:
+        """
+        Gets the current list of active adapters of the model.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not is_peft_available():
+            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                return module.active_adapter
```
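Taken together, the methods above cover the full adapter lifecycle. A hypothetical end-to-end sketch (the adapter names and checkpoint ID are placeholders, and multi-adapter activation requires a recent PEFT release):

```python
# End-to-end sketch of the PeftAdapterMixin lifecycle on a UNet.
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Attach two independent LoRA adapters; add_adapter() activates each one it injects.
unet.add_adapter(LoraConfig(r=4, target_modules=["to_q", "to_v"]), adapter_name="first")
unet.add_adapter(LoraConfig(r=8, target_modules=["to_q", "to_v"]), adapter_name="second")

unet.set_adapter("first")              # use only "first"
unet.set_adapter(["first", "second"])  # activate both (recent PEFT versions only)
print(unet.active_adapters())          # active adapter(s) reported by the first tuner layer

unet.disable_adapters()                # run inference with the base weights only
unet.enable_adapters()                 # re-enable the attached adapters
```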
