# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.backend import ops


def validate_einsum_equation(equation):
    # For simplicity, we greatly restrict possible einsum equations. The final
    # axis of the input must be the first axis of our kernel, and must not
    # appear in our output.
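    # For example, `"abc,cde->abde"` is valid (the input's last axis `c` is
    # contracted away), while `"abc,cde->abce"` is not (`c` appears in the
    # output).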
    left, right, output = re.split(",|->", equation)
    valid = (
        left[-1] == right[0]
        and left[-1] not in output
        and set(left[:-1]).isdisjoint(set(right[1:]))
    )
    if not valid:
        raise ValueError(
            "When passing an `EinsumDense` layer to a `LoraDense` layer, the "
            "einsum `equation` must always have the form `*x,x*->*`, where "
            "each `*` can be any sequence. Conceptually, the `equation` should "
            "always represent a dense matmul on the last axis of the input. "
            f"Received invalid equation `'{equation}'`."
        )


@keras_nlp_export("keras_nlp.layers.LoraDense")
class LoraDense(keras.layers.Layer):
    """A LoRA adapter layer for a dense input layer.

    This layer implements a low-rank decomposition of a dense transformation,
    as described in
    [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/pdf/2106.09685.pdf).
    This layer can be used to replace a dense layer with a layer whose
    parameters are mostly frozen.

    By default, this layer takes in an `inner_dense` layer, freezes its
    parameters, and builds a low-rank decomposed update to sum with the
    original `inner_dense` output. These update parameters can be merged back
    into the `inner_dense` kernel by calling `merge_weights()`.
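
    Conceptually, for an input `x`, the output of this layer is
    `inner_dense(x) + (alpha / rank) * (x @ lora_a) @ lora_b`, where `lora_a`
    projects the last axis of the input down to `rank` features and `lora_b`
    projects those features back up to the output shape of the inner kernel.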

    Args:
        inner_dense: A `keras.layers.Dense` or `keras.layers.EinsumDense`.
            The inner dense layer to freeze and wrap with the `LoraDense`
            layer. Note that for `EinsumDense` layers, the einsum equation must
            represent a dense transformation on the last axis of the input,
            though adding new axes to the output (e.g. a multi-head axis) is
            allowed.
        rank: int. The inner rank of the decomposed dense transformation. The
            lower this number, the fewer trainable parameters the layer will
            have.
        alpha: float. A constant value used for scaling the lora update. The
            lora update to the original dense transformation will be scaled by
            `alpha / rank`.
        lora_a_initializer: The initializer to use for the inner projection
            from layer inputs to the inner `rank` intermediate outputs.
        freeze_kernel: If true, the kernel of the inner dense layer will have
            `trainable` set to `False`.
        freeze_bias: If true, the bias of the inner dense layer will have
            `trainable` set to `False`.
        **kwargs: other keyword arguments.

    Examples:

    Wrap a `Dense` layer.
    ```python
    batch_size, feature_size = 4, 16
    rank = 4
    inputs = np.random.uniform(size=(batch_size, feature_size))
    inner_dense = keras.layers.Dense(feature_size)
    lora_dense = keras_nlp.layers.LoraDense(inner_dense, rank=rank)
    # The adapter output starts out identical to the wrapped dense output.
    assert np.allclose(inner_dense(inputs), lora_dense(inputs))

    # Add some random updates to the lora parameters.
    lora_dense.lora_a.assign(np.random.uniform(size=(feature_size, rank)))
    lora_dense.lora_b.assign(np.random.uniform(size=(rank, feature_size)))
    assert not np.allclose(inner_dense(inputs), lora_dense(inputs))

    # Merge the lora updates back into the wrapped dense layer.
    lora_dense.merge_weights()
    assert np.allclose(inner_dense(inputs), lora_dense(inputs))
    ```

    Wrap an `EinsumDense` layer with a multi-head projection.
    ```python
    batch_size, sequence_length, feature_size = 4, 10, 16
    num_heads = 2
    rank = 4
    inputs = np.random.uniform(size=(batch_size, sequence_length, feature_size))
    inner_dense = keras.layers.EinsumDense(
        "abc,cde->abde",
        output_shape=(sequence_length, num_heads, feature_size // num_heads),
    )
    lora_dense = keras_nlp.layers.LoraDense(inner_dense, rank=rank)
    # Output shape will be (4, 10, 2, 8).
    lora_dense(inputs)
    ```
    """

    def __init__(
        self,
        inner_dense,
        rank=8,
        alpha=8.0,
        lora_a_initializer="variance_scaling",
        freeze_kernel=True,
        freeze_bias=True,
        **kwargs,
    ):
        # Default to the same dtype as our inner layer.
        if "dtype" not in kwargs:
            kwargs["dtype"] = inner_dense.dtype_policy
        super().__init__(**kwargs)

        if not config.multi_backend():
            raise ValueError(
                "Lora only works with multi-backend Keras 3. Please set the "
                "`KERAS_BACKEND` environment variable to use this API."
            )

        if isinstance(inner_dense, keras.layers.Dense):
            self.inner_dense = inner_dense
        elif isinstance(inner_dense, keras.layers.EinsumDense):
            self.inner_dense = inner_dense
            validate_einsum_equation(inner_dense.equation)
        else:
            raise ValueError(
                "Only `Dense` and `EinsumDense` inner layers are supported. "
                f"Received: inner_dense={inner_dense}"
            )

        self.rank = rank
        self.alpha = alpha
        self.scale = alpha / rank
        self.freeze_kernel = freeze_kernel
        self.freeze_bias = freeze_bias
        self.lora_a_initializer = keras.initializers.get(lora_a_initializer)
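
        # If the wrapped layer is already built, eagerly build this layer so
        # the lora weights are created to match the inner kernel.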
        if inner_dense.built:
            self.build_from_config(inner_dense.get_build_config())

    def build(self, inputs_shape):
        if not self.inner_dense.built:
            self.inner_dense.build(inputs_shape)

        if self.freeze_kernel and self.inner_dense.kernel is not None:
            self.inner_dense.kernel.trainable = False

        if self.freeze_bias and self.inner_dense.bias is not None:
            self.inner_dense.bias.trainable = False

        input_dim = inputs_shape[-1]
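        # `lora_a` projects the last input axis down to `rank` features, and
        # `lora_b` maps those features back to the inner kernel's output
        # shape. `lora_b` starts at zero, so the initial lora update is a
        # no-op.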
        self.lora_a = self.add_weight(
            name="lora_a",
            shape=(input_dim, self.rank),
            initializer=self.lora_a_initializer,
        )
        kernel_shape = self.inner_dense.kernel.shape
        self.lora_b = self.add_weight(
            name="lora_b",
            shape=(self.rank,) + kernel_shape[1:],
            initializer="zeros",
        )
        self.built = True

    def merge_weights(self):
        """Merge lora updates into the wrapped dense layer.

        This function should only be called outside of any compiled context
        (e.g. not during `fit()`, `predict()` or `evaluate()`). It will merge
        the updates from the lora layers into the original dense layer, and
        re-initialize the lora variables.
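
        Example (a minimal sketch; assumes a built `lora_dense` adapter
        wrapping `inner_dense`, as in the class examples above):
        ```python
        lora_dense.merge_weights()
        # The wrapped layer alone now matches the adapted output.
        assert np.allclose(inner_dense(inputs), lora_dense(inputs))
        ```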
        """
        if not self.built:
            return

        # Compute matmul of lora_a and lora_b to get a kernel-sized update.
        update = ops.tensordot(self.lora_a, self.lora_b, axes=([-1], [0]))
        update = update * ops.cast(self.scale, update.dtype)
        # Add the lora update back into the inner dense kernel.
        self.inner_dense.kernel.assign_add(update)
        # Re-initialize the lora weights.
        self.lora_a.assign(
            self.lora_a_initializer(self.lora_a.shape, self.lora_a.dtype)
        )
        self.lora_b.assign(ops.zeros_like(self.lora_b))

    def call(self, inputs):
        original_output = self.inner_dense(inputs)
        # Compute the low-rank intermediate output.
        update = ops.matmul(inputs, self.lora_a)
        # Use the matching dense computation for a Dense or EinsumDense.
        if isinstance(self.inner_dense, keras.layers.Dense):
            update = ops.matmul(update, self.lora_b)
        else:
            update = ops.einsum(self.inner_dense.equation, update, self.lora_b)
        # Scale and sum the lora update with the original frozen output.
        return original_output + update * ops.cast(self.scale, update.dtype)

    @classmethod
    def from_config(cls, config):
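        # The inner layer is stored serialized in the config (see
        # `get_config()` below), so deserialize it before calling `__init__`.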
        config["inner_dense"] = keras.layers.deserialize(config["inner_dense"])
        return super().from_config(config)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "inner_dense": keras.layers.serialize(self.inner_dense),
                "rank": self.rank,
                "alpha": self.alpha,
                "lora_a_initializer": keras.initializers.serialize(
                    self.lora_a_initializer
                ),
                "freeze_kernel": self.freeze_kernel,
                "freeze_bias": self.freeze_bias,
            }
        )
        return config