# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available

class PeftAdapterMixin:
    """
    A class containing all functions for loading and using adapter weights that are supported in the PEFT library.
    For more details about adapters and injecting them into a transformer-based model, check out the PEFT
    documentation: https://huggingface.co/docs/peft/index.

    With this mixin, if the correct PEFT version is installed, it is possible to:

    - Attach new adapters to the model.
    - Attach multiple adapters and iteratively activate / deactivate them.
    - Activate / deactivate all adapters from the model.
    - Get a list of the active adapters.
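
    Example (a minimal sketch: the checkpoint name, the LoRA hyperparameters, and the `to_q` / `to_v` target
    modules are illustrative assumptions, not values required by this mixin):

    ```py
    from diffusers import UNet2DConditionModel
    from peft import LoraConfig

    # Any diffusers model that inherits this mixin works; the UNet is only used as an example here.
    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

    lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_v"])
    unet.add_adapter(lora_config, adapter_name="my_lora")

    unet.set_adapter("my_lora")
    print(unet.active_adapters())  # e.g. ["my_lora"]

    unet.disable_adapters()  # run inference with the base weights only
    unet.enable_adapters()  # re-enable the previously active adapter(s)
    ```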
    """

    _hf_peft_config_loaded = False

    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
        r"""
        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is
        assigned to the adapter to follow the convention of the PEFT library.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_config ([`~peft.PeftConfig`]):
                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption
                prompt methods.
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
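
        Example (a minimal sketch; the checkpoint name, the adapter names, the LoRA ranks, and the `to_q` / `to_v`
        target modules are illustrative assumptions):

        ```py
        from diffusers import UNet2DConditionModel
        from peft import LoraConfig

        unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

        # Attach two LoRA adapters under different names; the adapter added last becomes the active one.
        unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_v"]), adapter_name="adapter_1")
        unet.add_adapter(LoraConfig(r=8, lora_alpha=8, target_modules=["to_q", "to_v"]), adapter_name="adapter_2")
        ```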
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        from peft import PeftConfig, inject_adapter_in_model

        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
        elif adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        # Unlike transformers, here we don't need to retrieve the name_or_path of the UNet because the loading
        # logic is handled by `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
        adapter_config.base_model_name_or_path = None
        inject_adapter_in_model(adapter_config, self, adapter_name)
        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
        """
        Sets a specific adapter by forcing the model to only use that adapter and disabling the other adapters.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        official documentation: https://huggingface.co/docs/peft

        Args:
            adapter_name (`Union[str, List[str]]`):
                The name of the adapter to set, or a list of adapter names in the multi-adapter case.
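
        Example (a minimal sketch, assuming `model` is a diffusers model that inherits this mixin and that adapters
        named `"adapter_1"` and `"adapter_2"` were previously attached with [`~PeftAdapterMixin.add_adapter`]):

        ```py
        # Activate a single adapter.
        model.set_adapter("adapter_1")

        # Or activate several adapters at once (requires a PEFT version with multi-adapter support).
        model.set_adapter(["adapter_1", "adapter_2"])
        ```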
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]

        missing = set(adapter_name) - set(self.peft_config)
        if len(missing) > 0:
            raise ValueError(
                f"The following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
                f" The currently loaded adapters are: {list(self.peft_config.keys())}"
            )

        from peft.tuners.tuners_utils import BaseTunerLayer

        _adapters_has_been_set = False

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                # Previous versions of PEFT do not support multi-adapter inference.
                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
                    raise ValueError(
                        "You are trying to set multiple adapters, but your PEFT version does not support"
                        " multi-adapter inference. Please upgrade to the latest version of PEFT:"
                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
                    )
                else:
                    module.active_adapter = adapter_name
                _adapters_has_been_set = True

        if not _adapters_has_been_set:
            raise ValueError(
                "Did not succeed in setting the adapter. Please make sure you are using a model that supports adapters."
            )

    def disable_adapters(self) -> None:
        r"""
        Disable all adapters attached to the model and fall back to inference with the base model only.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        official documentation: https://huggingface.co/docs/peft
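
        Example (a minimal sketch, assuming `model` is a diffusers model that inherits this mixin and that at least
        one adapter was previously attached with [`~PeftAdapterMixin.add_adapter`]):

        ```py
        model.disable_adapters()
        # From here on, inference uses the base model weights only.
        ```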
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=False)
                else:
                    # Support for older PEFT versions.
                    module.disable_adapters = True

    def enable_adapters(self) -> None:
        """
        Enable the adapters that are attached to the model. The adapter(s) that were most recently set active (for
        example with `set_adapter`) are used for inference again.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        official documentation: https://huggingface.co/docs/peft
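
        Example (a minimal sketch, assuming `model` is a diffusers model that inherits this mixin and that its
        adapters were previously disabled with [`~PeftAdapterMixin.disable_adapters`]):

        ```py
        model.enable_adapters()
        # The previously active adapter(s) are used for inference again.
        ```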
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=True)
                else:
                    # Support for older PEFT versions.
                    module.disable_adapters = False

    def active_adapters(self) -> List[str]:
        """
        Gets the current list of active adapters of the model.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        official documentation: https://huggingface.co/docs/peft
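
        Example (a minimal sketch, assuming `model` is a diffusers model that inherits this mixin and that an
        adapter named `"adapter_1"` was attached and set active):

        ```py
        model.active_adapters()
        # e.g. ["adapter_1"]
        ```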
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        # Read the active adapter(s) from the first PEFT layer found; `set_adapter` keeps all layers in sync.
        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                return module.active_adapter