
Commit 07e1cc2

Add a lora dense layer (keras-team#1263)
* Add a lora dense layer

  Co-authored-by: Abheesht <[email protected]>

* address comments
* Fix merge conflict
* minor fix
* another einsum restriction
* Last doc nit from Ian

--------

Co-authored-by: Abheesht <[email protected]>
1 parent 8cab8ef commit 07e1cc2

File tree

4 files changed: +378 -2 lines changed


keras_nlp/conftest.py

Lines changed: 8 additions & 2 deletions
@@ -69,17 +69,23 @@ def pytest_collection_modifyitems(config, items):
         not run_extra_large_tests,
         reason="need --run_extra_large option to run",
     )
-    skip_tf_only = pytest.mark.skipif(
+    tf_only = pytest.mark.skipif(
         not backend_config.backend() == "tensorflow",
         reason="tests only run on tf backend",
     )
+    multi_backend_only = pytest.mark.skipif(
+        not backend_config.multi_backend(),
+        reason="tests only run on with multi-backend keras",
+    )
     for item in items:
         if "large" in item.keywords:
             item.add_marker(skip_large)
         if "extra_large" in item.keywords:
             item.add_marker(skip_extra_large)
         if "tf_only" in item.keywords:
-            item.add_marker(skip_tf_only)
+            item.add_marker(tf_only)
+        if "multi_backend_only" in item.keywords:
+            item.add_marker(multi_backend_only)


 # Disable traceback filtering for quicker debugging of tests failures.
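
The conftest change mirrors the existing `tf_only` handling: any collected test carrying the `multi_backend_only` keyword is marked to be skipped unless Keras is running with multi-backend (Keras 3) support. A minimal sketch of how a test might opt in (the test function below is hypothetical, not part of this commit):

```python
import pytest

from keras_nlp.backend import keras
from keras_nlp.layers import LoraDense


@pytest.mark.multi_backend_only  # adds "multi_backend_only" to the test's keywords
def test_lora_dense_builds():  # hypothetical test, for illustration only
    # On a single-backend (tf.keras) install, LoraDense raises a ValueError,
    # so the marker keeps this test from running there at all.
    layer = LoraDense(keras.layers.Dense(16), rank=4)
    assert layer.rank == 4
```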

keras_nlp/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
     CachedMultiHeadAttention,
 )
 from keras_nlp.layers.modeling.f_net_encoder import FNetEncoder
+from keras_nlp.layers.modeling.lora_dense import LoraDense
 from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead
 from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
 from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
keras_nlp/layers/modeling/lora_dense.py

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.backend import ops


def validate_einsum_equation(equation):
    # For simplicity, we greatly restrict possible einsum equations. The final
    # axis of the input must be the first axis of our kernel, and must not
    # appear in our output.
    left, right, output = re.split(",|->", equation)
    valid = (
        left[-1] == right[0]
        and left[-1] not in output
        and set(left[:-1]).isdisjoint(set(right[1:]))
    )
    if not valid:
        raise ValueError(
            "When passing a `EinsumDense` layer to a `LoraDense` layer, the "
            "einsum `equation` must always have the form `*x,x*->*`, where "
            "each `*` can be any sequence. Conceptually, the `equation` should "
            "always represent a dense matmul on the last axis of the input. "
            f"Received invalid equation `'{equation}'`."
        )


@keras_nlp_export("keras_nlp.layers.LoraDense")
class LoraDense(keras.layers.Layer):
    """A LoRA adapter layer for a dense input layer.

    This layer implements a low-rank decomposition of a dense transformation,
    as described in
    [LoRA: Low-Rank Adaptation Of Large Language Models](https://arxiv.org/pdf/2106.09685.pdf).
    This layer can be used to replace a dense layer with a layer whose
    parameters are mostly frozen.

    By default, this layer takes in an `inner_dense` layer, freezes its
    parameters, and builds a low-rank decomposed update to sum with the
    original `inner_dense` output. These update parameters can be merged back
    into the `inner_dense` kernel by calling `merge_weights()`.

    Args:
        inner_dense: A `keras.layers.Dense` or `keras.layers.EinsumDense`.
            The inner dense layer to freeze and wrap with the `LoraDense`
            layer. Note that for `EinsumDense` layers, the einsum equation must
            represent a dense transformation on the last axis of the input,
            though adding new axes to the output (e.g. a multi-head axis) is
            allowed.
        rank: int. The inner rank of the decomposed dense transformation. The
            lower this number, the fewer trainable parameters the layer will
            have.
        alpha: float. A constant value used for scaling the lora update. The
            lora update to the original dense transformation will be scaled by
            `alpha / rank`.
        lora_a_initializer: The initializer to use for the inner projection
            from layer inputs to the inner `rank` intermediate outputs.
        freeze_kernel: If true, the kernel of the inner dense layer will have
            `trainable` set to `False`.
        freeze_bias: If true, the bias of the inner dense layer will have
            `trainable` set to `False`.
        **kwargs: other keyword arguments.

    Examples:

    Wrap a `Dense` layer.
    ```python
    batch_size, feature_size = 4, 16
    rank = 4
    inputs = np.random.uniform(size=(batch_size, feature_size))
    inner_dense = keras.layers.Dense(feature_size)
    lora_dense = keras_nlp.layers.LoraDense(inner_dense, rank=4)
    # Output with inner dense begins equal.
    assert np.allclose(inner_dense(inputs), lora_dense(inputs))

    # Add some random updates to the lora parameters.
    lora_dense.lora_a.assign(np.random.uniform(size=(feature_size, rank)))
    lora_dense.lora_b.assign(np.random.uniform(size=(rank, feature_size)))
    assert not np.allclose(inner_dense(inputs), lora_dense(inputs))

    # Merge the lora updates back into the inner dense layer.
    lora_dense.merge_weights()
    assert np.allclose(inner_dense(inputs), lora_dense(inputs))
    ```

    Wrap an `EinsumDense` layer with a multi-head projection.
    ```python
    batch_size, sequence_length, feature_size = 4, 10, 16
    num_heads = 2
    rank = 4
    inputs = np.random.uniform(size=(batch_size, sequence_length, feature_size))
    inner_dense = keras.layers.EinsumDense(
        "abc,cde->abde",
        output_shape=(sequence_length, num_heads, feature_size // num_heads),
    )
    lora_dense = keras_nlp.layers.LoraDense(inner_dense, rank=4)
    # Output shape (4, 10, 2, 8)
    lora_dense(inputs)
    ```
    """

    def __init__(
        self,
        inner_dense,
        rank=8,
        alpha=8.0,
        lora_a_initializer="variance_scaling",
        freeze_kernel=True,
        freeze_bias=True,
        **kwargs,
    ):
        # Default to the same dtype as our inner layer.
        if "dtype" not in kwargs:
            kwargs["dtype"] = inner_dense.dtype_policy
        super().__init__(**kwargs)

        if not config.multi_backend():
            raise ValueError(
                "Lora only works with multi-backend Keras 3. Please set the "
                "`KERAS_BACKEND` environment variable to use this API."
            )

        if isinstance(inner_dense, keras.layers.Dense):
            self.inner_dense = inner_dense
        elif isinstance(inner_dense, keras.layers.EinsumDense):
            self.inner_dense = inner_dense
            validate_einsum_equation(inner_dense.equation)
        else:
            raise ValueError(
                "Only `Dense` and `EinsumDense` inner layers are supported. "
                f"Received: inner_dense={inner_dense}"
            )

        self.rank = rank
        self.alpha = alpha
        self.scale = alpha / rank
        self.freeze_kernel = freeze_kernel
        self.freeze_bias = freeze_bias
        self.lora_a_initializer = keras.initializers.get(lora_a_initializer)

        if inner_dense.built:
            self.build_from_config(inner_dense.get_build_config())

    def build(self, inputs_shape):
        if not self.inner_dense.built:
            self.inner_dense.build(inputs_shape)

        if self.freeze_kernel and self.inner_dense.kernel is not None:
            self.inner_dense.kernel.trainable = False

        if self.freeze_bias and self.inner_dense.bias is not None:
            self.inner_dense.bias.trainable = False

        input_dim = inputs_shape[-1]
        self.lora_a = self.add_weight(
            name="lora_a",
            shape=(input_dim, self.rank),
            initializer=self.lora_a_initializer,
        )
        kernel_shape = self.inner_dense.kernel.shape
        self.lora_b = self.add_weight(
            name="lora_b",
            shape=(self.rank,) + kernel_shape[1:],
            initializer="zeros",
        )
        self.built = True

    def merge_weights(self):
        """Merge lora updates into the wrapped dense layer.

        This function should only be called outside of any compiled context
        (e.g. not during `fit()`, `predict()` or `evaluate()`). It will merge
        the updates from the lora layers into the original dense layer, and
        re-initialize the lora variables.
        """
        if not self.built:
            return

        # Compute matmul of lora_a and lora_b to get a kernel sized update.
        update = ops.tensordot(self.lora_a, self.lora_b, axes=([-1], [0]))
        update = update * ops.cast(self.scale, update.dtype)
        # Add lora updates back into the inner dense kernel.
        self.inner_dense.kernel.assign_add(update)
        # Re-initialize lora weights.
        self.lora_a.assign(
            self.lora_a_initializer(self.lora_a.shape, self.lora_a.dtype)
        )
        self.lora_b.assign(ops.zeros_like(self.lora_b))

    def call(self, inputs):
        original_output = self.inner_dense(inputs)
        # Compute the low-rank intermediate output.
        update = ops.matmul(inputs, self.lora_a)
        # Use the matching dense computation for a Dense or EinsumDense.
        if isinstance(self.inner_dense, keras.layers.Dense):
            update = ops.matmul(update, self.lora_b)
        else:
            update = ops.einsum(self.inner_dense.equation, update, self.lora_b)
        # Scale and sum the lora update with the original frozen output.
        return original_output + update * ops.cast(self.scale, update.dtype)

    @classmethod
    def from_config(cls, config):
        config["inner_dense"] = keras.layers.deserialize(config["inner_dense"])
        return super().from_config(config)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "inner_dense": keras.layers.serialize(self.inner_dense),
                "rank": self.rank,
                "alpha": self.alpha,
                "lora_a_initializer": keras.initializers.serialize(
                    self.lora_a_initializer
                ),
                "freeze_kernel": self.freeze_kernel,
                "freeze_bias": self.freeze_bias,
            }
        )
        return config
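
Two details of the implementation above may be worth unpacking. First, the `validate_einsum_equation` check only accepts equations that contract the input's last axis against the kernel's first axis and drop that axis from the output. A quick illustrative sketch (standalone, not part of the commit; it assumes the module path shown in the file header above):

```python
from keras_nlp.layers.modeling.lora_dense import validate_einsum_equation

# Accepted: the input's last axis `c` is the kernel's first axis and is
# summed out, leaving new output axes `d` and `e` (e.g. a multi-head split).
validate_einsum_equation("abc,cde->abde")
# Accepted: an ordinary dense matmul on the last axis.
validate_einsum_equation("ab,bc->ac")

# Rejected: the contracted axis `c` reappears in the output, so this is not
# a dense transformation of the last input axis.
try:
    validate_einsum_equation("abc,cd->abcd")
except ValueError as err:
    print(err)
```

Second, `call()` and `merge_weights()` compute the same low-rank correction in two places: `call()` adds `scale * (inputs @ lora_a) @ lora_b` on top of the frozen dense output, while `merge_weights()` folds `scale * (lora_a @ lora_b)` directly into the frozen kernel. A minimal NumPy sketch of that equivalence for the plain `Dense` case (illustrative values only; the bias is omitted since the merge never touches it):

```python
import numpy as np

batch_size, feature_size, rank = 4, 16, 4
alpha = 8.0
scale = alpha / rank  # matches LoraDense's `alpha / rank` scaling

x = np.random.uniform(size=(batch_size, feature_size))
kernel = np.random.uniform(size=(feature_size, feature_size))  # frozen Dense kernel
lora_a = np.random.uniform(size=(feature_size, rank))
lora_b = np.random.uniform(size=(rank, feature_size))

# What `call()` computes: frozen output plus the scaled low-rank update.
adapter_output = x @ kernel + scale * (x @ lora_a) @ lora_b

# What the layer computes after `merge_weights()`: one dense matmul against
# a kernel with `scale * lora_a @ lora_b` added in.
merged_output = x @ (kernel + scale * (lora_a @ lora_b))

assert np.allclose(adapter_output, merged_output)
```

After the merge, `lora_a` is re-initialized and `lora_b` is zeroed, so the layer's output is unchanged while further training can accumulate a fresh low-rank update.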
