Commit b6a6e27

Simplify running KerasNLP with Keras 3 (keras-team#1308)
* Simplify running KerasNLP with Keras 3

  We should not land this until Keras 3, TensorFlow 2.15, and keras-nlp-nightly are released.
* Address comments
* Tweaks
* Add link
* fix link
1 parent 36a62a6 commit b6a6e27

12 files changed: 76 additions & 116 deletions

README.md

Lines changed: 36 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,9 @@
44
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/keras-team/keras-nlp/issues)
55

66
KerasNLP is a natural language processing library that works natively
7-
with TensorFlow, JAX, or PyTorch. Built on [multi-backend Keras](https://keras.io/keras_core/announcement/)
8-
(Keras 3), these models, layers, metrics, and tokenizers can be trained and
9-
serialized in any framework and re-used in another without costly migrations.
7+
with TensorFlow, JAX, or PyTorch. Built on Keras 3, these models, layers,
8+
metrics, and tokenizers can be trained and serialized in any framework and
9+
re-used in another without costly migrations.
1010

1111
KerasNLP supports users through their entire development cycle. Our workflows
1212
are built from modular components that have state-of-the-art preset weights when
@@ -40,27 +40,51 @@ to start learning our API. We welcome [contributions](CONTRIBUTING.md).
4040

4141
## Installation
4242

43-
To install the latest official release:
43+
KerasNLP supports both Keras 2 and Keras 3. We recommend Keras 3 for all new
44+
users, as it enables using KerasNLP models and layers with JAX, TensorFlow and
45+
PyTorch.
46+
47+
### Keras 2 Installation
48+
49+
To install the latest KerasNLP release with Keras 2, simply run:
50+
51+
```
52+
pip install --upgrade keras-nlp
53+
```
54+
55+
### Keras 3 Installation
56+
57+
There are currently two ways to install Keras 3 with KerasNLP. To install the
58+
stable versions of KerasNLP and Keras 3, you should install Keras 3 **after**
59+
installing KerasNLP. This is a temporary step while TensorFlow is pinned to
60+
Keras 2, and will no longer be necessary after TensorFlow 2.16.
4461

4562
```
46-
pip install keras-nlp --upgrade
63+
pip install --upgrade keras-nlp
64+
pip install --upgrade "keras>=3"
4765
```
4866

49-
To install the latest unreleased changes to the library, we recommend using
50-
pip to install directly from the master branch on github:
67+
To install the latest nightly changes for both KerasNLP and Keras, you can use
68+
our nightly package.
5169

5270
```
53-
pip install git+https://github.com/keras-team/keras-nlp.git --upgrade
71+
pip install --upgrade keras-nlp-nightly
5472
```
5573

74+
> [!IMPORTANT]
75+
> Keras 3 will not function with TensorFlow 2.14 or earlier.
76+
77+
Read [Getting started with Keras](https://keras.io/getting_started/) for more information
78+
on installing Keras 3 and compatibility with different frameworks.
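
The version constraint in the note above can be checked from Python before importing anything heavy. The following is a minimal sketch only: `major_minor` and `keras_3_is_usable` are hypothetical helper names, not part of KerasNLP, and the only rule encoded is the one stated above (Keras 3 needs a TensorFlow newer than 2.14).

```python
import re


def major_minor(version):
    """Extract (major, minor) from a version string such as "2.15.0"."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def keras_3_is_usable(keras_version, tf_version):
    """Hypothetical check: Keras is 3.x and TensorFlow is newer than 2.14."""
    keras_major = major_minor(keras_version)[0]
    return keras_major >= 3 and major_minor(tf_version) >= (2, 15)
```

The version strings can be read with `importlib.metadata.version("keras")` and `importlib.metadata.version("tensorflow")`, assuming both packages are installed under those names.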
79+
5680
## Quickstart
5781

5882
Fine-tune BERT on a small sentiment analysis task using the
5983
[`keras_nlp.models`](https://keras.io/api/keras_nlp/models/) API:
6084

6185
```python
6286
import os
63-
os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow", or "torch".
87+
os.environ["KERAS_BACKEND"] = "tensorflow" # Or "jax" or "torch"!
6488

6589
import keras_nlp
6690
import tensorflow_datasets as tfds
@@ -87,14 +111,9 @@ For more in depth guides and examples, visit https://keras.io/keras_nlp/.
87111

88112
## Configuring your backend
89113

90-
**Keras 3** is an upcoming release of the Keras library which supports
91-
TensorFlow, Jax or Torch as backends. This is supported today in KerasNLP,
92-
but will not be enabled by default until the official release of Keras 3. If you
93-
`pip install keras-nlp` and run a script or notebook without changes, you will
94-
be using TensorFlow and **Keras 2**.
95-
96-
If you would like to enable a preview of the Keras 3 behavior, you can do
97-
so by setting the `KERAS_BACKEND` environment variable. For example:
114+
If you have Keras 3 installed in your environment (see installation above),
115+
you can use KerasNLP with any of JAX, TensorFlow and PyTorch. To do so, set the
116+
`KERAS_BACKEND` environment variable. For example:
98117

99118
```shell
100119
export KERAS_BACKEND=jax
@@ -113,16 +132,6 @@ import keras_nlp
113132
> Make sure to set `KERAS_BACKEND` before importing any Keras libraries; it
114133
> will be used to set up Keras when it is first imported.
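
The ordering requirement in the note above can be made explicit in code. This is a sketch under stated assumptions: `set_keras_backend` is a hypothetical helper, not a KerasNLP or Keras API, and the list of backend names is taken from the README text. It refuses to set the variable once `keras` is already imported, since the setting would then be too late to take effect.

```python
import os
import sys

# Backend names as listed in the README text above.
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def set_keras_backend(name, environ=None, loaded_modules=None):
    """Hypothetical helper: set KERAS_BACKEND, but only before Keras is imported."""
    environ = os.environ if environ is None else environ
    loaded_modules = sys.modules if loaded_modules is None else loaded_modules
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}")
    if "keras" in loaded_modules:
        # Keras reads the variable when it is first imported, so this is too late.
        raise RuntimeError("Set KERAS_BACKEND before importing any Keras libraries.")
    environ["KERAS_BACKEND"] = name
    return name
```

Calling `set_keras_backend("jax")` at the very top of a script, before `import keras_nlp`, has the same effect as the `os.environ` assignment in the Quickstart.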
115134
116-
Until the Keras 3 release, KerasNLP will use a preview of Keras 3 on PyPI named
117-
[keras-core](https://pypi.org/project/keras-core/).
118-
119-
> [!IMPORTANT]
120-
> If you set `KERAS_BACKEND` variable, you should `import keras_core as keras`
121-
> instead of `import keras`. This is a temporary step until Keras 3 is out!
122-
123-
To restore the default **Keras 2** behavior, `unset KERAS_BACKEND` before
124-
importing Keras and KerasNLP.
125-
126135
## Compatibility
127136

128137
We follow [Semantic Versioning](https://semver.org/), and plan to

keras_nlp/backend/__init__.py

Lines changed: 9 additions & 7 deletions
@@ -14,14 +14,16 @@
1414
"""
1515
Keras backend module.
1616
17-
This module adds a temporarily Keras API surface that is fully under KerasNLP
18-
control. This allows us to switch between `keras_core` and `tf.keras`, as well
19-
as add shims to support older version of `tf.keras`.
17+
This module adds a temporary Keras API surface that is fully under KerasNLP
18+
control. The goal is to allow us to write Keras 3-like code everywhere, while
19+
still supporting Keras 2. We do this by using the `keras_core` package with
20+
Keras 2 to backport Keras 3 numerics APIs (`keras.ops` and `keras.random`) into
21+
Keras 2. The sub-modules exposed are as follows:
2022
21-
- `config`: check which backend is being run.
22-
- `keras`: The full `keras` API (via `keras_core` or `tf.keras`).
23-
- `ops`: `keras_core.ops`, always tf backed if using `tf.keras`.
24-
- `random`: `keras_core.random`, always tf backed if using `tf.keras`.
23+
- `config`: check which version of Keras is being run.
24+
- `keras`: The full `keras` API with compat shims for older Keras versions.
25+
- `ops`: `keras.ops` for Keras 3 or `keras_core.ops` for Keras 2.
26+
- `random`: `keras.random` for Keras 3 or `keras_core.random` for Keras 2.
2527
"""
2628

2729
from keras_nlp.backend import config
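
The mapping described in the docstring above can be sketched as a small lookup. This is illustrative only and not the module's actual implementation: `numerics_sources` and `load_numerics` are hypothetical names, and the sketch assumes that `random` is backed by `keras_core.random` under Keras 2.

```python
import importlib


def numerics_sources(use_keras_3):
    """Return the module path backing each numerics sub-module (hypothetical)."""
    if use_keras_3:
        return {"ops": "keras.ops", "random": "keras.random"}
    # Assumption: under Keras 2, keras_core backports both numerics APIs.
    return {"ops": "keras_core.ops", "random": "keras_core.random"}


def load_numerics(use_keras_3):
    """Import the selected modules; requires keras or keras_core to be installed."""
    return {
        name: importlib.import_module(path)
        for name, path in numerics_sources(use_keras_3).items()
    }
```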

keras_nlp/backend/config.py

Lines changed: 12 additions & 61 deletions
@@ -12,56 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import json
1615
import os
1716

18-
_MULTI_BACKEND = False
19-
_USE_KERAS_3 = False
20-
21-
# Set Keras base dir path given KERAS_HOME env variable, if applicable.
22-
# Otherwise either ~/.keras or /tmp.
23-
if "KERAS_HOME" in os.environ:
24-
_keras_dir = os.environ.get("KERAS_HOME")
25-
else:
26-
_keras_base_dir = os.path.expanduser("~")
27-
if not os.access(_keras_base_dir, os.W_OK):
28-
_keras_base_dir = "/tmp"
29-
_keras_dir = os.path.join(_keras_base_dir, ".keras")
30-
31-
# Attempt to read KerasNLP config file.
32-
_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras_nlp.json"))
33-
if os.path.exists(_config_path):
34-
try:
35-
with open(_config_path) as f:
36-
_config = json.load(f)
37-
except ValueError:
38-
_config = {}
39-
_MULTI_BACKEND = _config.get("multi_backend", _MULTI_BACKEND)
40-
41-
# Save config file, if possible.
42-
if not os.path.exists(_keras_dir):
43-
try:
44-
os.makedirs(_keras_dir)
45-
except OSError:
46-
# Except permission denied and potential race conditions
47-
# in multi-threaded environments.
48-
pass
49-
50-
if not os.path.exists(_config_path):
51-
_config = {
52-
"multi_backend": _MULTI_BACKEND,
53-
}
54-
try:
55-
with open(_config_path, "w") as f:
56-
f.write(json.dumps(_config, indent=4))
57-
except IOError:
58-
# Except permission denied.
59-
pass
60-
61-
# If KERAS_BACKEND is set in the environment use multi-backend keras.
62-
if "KERAS_BACKEND" in os.environ and os.environ["KERAS_BACKEND"]:
63-
_MULTI_BACKEND = True
64-
6517

6618
def detect_if_tensorflow_uses_keras_3():
6719
# We follow the version of keras that tensorflow is configured to use.
@@ -84,29 +36,28 @@ def detect_if_tensorflow_uses_keras_3():
8436

8537

8638
_USE_KERAS_3 = detect_if_tensorflow_uses_keras_3()
87-
if _USE_KERAS_3:
88-
_MULTI_BACKEND = True
39+
40+
if not _USE_KERAS_3:
41+
    backend = os.environ.get("KERAS_BACKEND")
42+
    if backend and backend != "tensorflow":
43+
        raise RuntimeError(
44+
            "When running Keras 2, the `KERAS_BACKEND` environment variable "
45+
            f"must either be unset or `'tensorflow'`. Received: `{backend}`. "
46+
            "To set another backend, please install Keras 3. See "
47+
            "https://github.com/keras-team/keras-nlp#installation"
48+
        )
8949

9050

9151
def keras_3():
9252
    """Check if Keras 3 is being used."""
9353
    return _USE_KERAS_3
9454

9555

96-
def multi_backend():
97-
    """Check if multi-backend Keras is enabled."""
98-
    return _MULTI_BACKEND
99-
100-
10156
def backend():
10257
    """Check the backend framework."""
103-
    if not multi_backend():
104-
        return "tensorflow"
10558
    if not keras_3():
106-
        import keras_core
107-
108-
        return keras_core.config.backend()
59+
        return "tensorflow"
10960

110-
    from tensorflow import keras
61+
    import keras
11162

11263
    return keras.config.backend()
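
The Keras 2 branch of this change can be exercised in isolation. The sketch below restates the new check as a standalone function so it can be tried without Keras or TensorFlow installed. `keras_2_backend` is a hypothetical name, the environment is passed in as a plain mapping rather than read from `os.environ`, and the error message is abbreviated.

```python
def keras_2_backend(environ):
    """Return the backend used under Keras 2, or raise if another was requested."""
    backend = environ.get("KERAS_BACKEND")
    # An unset or empty variable is accepted, as is an explicit "tensorflow".
    if backend and backend != "tensorflow":
        raise RuntimeError(
            "When running Keras 2, the `KERAS_BACKEND` environment variable "
            f"must either be unset or `'tensorflow'`. Received: `{backend}`."
        )
    return "tensorflow"
```

The design choice visible in the diff is to fail at import time rather than silently fall back to TensorFlow, so a user who sets `KERAS_BACKEND=jax` without Keras 3 finds out immediately.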

keras_nlp/backend/keras.py

Lines changed: 0 additions & 2 deletions
@@ -20,8 +20,6 @@
2020

2121
if config.keras_3():
2222
from keras import * # noqa: F403, F401
23-
elif config.multi_backend():
24-
from keras_core import * # noqa: F403, F401
2523
else:
2624
from tensorflow.keras import * # noqa: F403, F401
2725

keras_nlp/conftest.py

Lines changed: 5 additions & 5 deletions
@@ -73,8 +73,8 @@ def pytest_collection_modifyitems(config, items):
7373
not backend_config.backend() == "tensorflow",
7474
reason="tests only run on tf backend",
7575
)
76-
multi_backend_only = pytest.mark.skipif(
77-
not backend_config.multi_backend(),
76+
keras_3_only = pytest.mark.skipif(
77+
not backend_config.keras_3(),
7878
reason="tests only run with Keras 3",
7979
)
8080
for item in items:
@@ -84,11 +84,11 @@ def pytest_collection_modifyitems(config, items):
8484
item.add_marker(skip_extra_large)
8585
if "tf_only" in item.keywords:
8686
item.add_marker(tf_only)
87-
if "multi_backend_only" in item.keywords:
88-
item.add_marker(multi_backend_only)
87+
if "keras_3_only" in item.keywords:
88+
item.add_marker(keras_3_only)
8989

9090

9191
# Disable traceback filtering for quicker debugging of test failures.
9292
tf.debugging.disable_traceback_filtering()
93-
if backend_config.multi_backend():
93+
if backend_config.keras_3():
9494
keras.config.disable_traceback_filtering()

keras_nlp/layers/modeling/cached_multi_head_attention_test.py

Lines changed: 2 additions & 2 deletions
@@ -36,9 +36,9 @@ def test_layer_behaviors(self):
3636
expected_output_shape=(2, 4, 6),
3737
expected_num_trainable_weights=8,
3838
expected_num_non_trainable_variables=1,
39-
# tf.keras does not handle mixed precision correctly when not set
39+
# Keras 2 does not handle mixed precision correctly when not set
4040
# globally.
41-
run_mixed_precision_check=config.multi_backend(),
41+
run_mixed_precision_check=config.keras_3(),
4242
)
4343

4444
def test_cache_call_is_correct(self):

keras_nlp/layers/modeling/lora_dense.py

Lines changed: 3 additions & 3 deletions
@@ -127,10 +127,10 @@ def __init__(
127127
kwargs["dtype"] = inner_dense.dtype_policy
128128
super().__init__(**kwargs)
129129

130-
if not config.multi_backend():
130+
if not config.keras_3():
131131
raise ValueError(
132-
"Lora only works with multi-backend Keras 3. Please set the "
133-
"`KERAS_BACKEND` environment variable to use this API."
132+
"Lora requires Keras 3, but Keras 2 is installed. Please "
133+
"see https://github.com/keras-team/keras-nlp#installation"
134134
)
135135

136136
if isinstance(inner_dense, keras.layers.Dense):

keras_nlp/layers/modeling/lora_dense_test.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
2020
from keras_nlp.tests.test_case import TestCase
2121

2222

23-
@pytest.mark.multi_backend_only
23+
@pytest.mark.keras_3_only
2424
class LoraDenseTest(TestCase):
2525
def test_layer_behaviors(self):
2626
self.run_layer_test(

keras_nlp/models/task.py

Lines changed: 1 addition & 1 deletion
@@ -319,7 +319,7 @@ def bold_text(x):
319319
print_fn(console.end_capture(), line_break=False)
320320

321321
# Avoid `tf.keras.Model.summary()`, so the above output matches.
322-
if config.multi_backend():
322+
if config.keras_3():
323323
super().summary(
324324
line_length=line_length,
325325
positions=positions,

keras_nlp/tests/test_case.py

Lines changed: 5 additions & 5 deletions
@@ -148,7 +148,7 @@ def call(self, x):
148148
model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile)
149149
model.fit(input_data, output_data, verbose=0)
150150

151-
if config.multi_backend():
151+
if config.keras_3():
152152
# Build test.
153153
layer = cls(**init_kwargs)
154154
if isinstance(input_data, dict):
@@ -253,8 +253,8 @@ def run_serialization_test(self, instance):
253253
revived_cfg = revived_instance.get_config()
254254
revived_cfg_json = json.dumps(revived_cfg, sort_keys=True, indent=4)
255255
self.assertEqual(cfg_json, revived_cfg_json)
256-
# Dir tests only work on keras-core.
257-
if config.multi_backend():
256+
# Dir tests only work with Keras 3.
257+
if config.keras_3():
258258
self.assertEqual(ref_dir, dir(revived_instance))
259259

260260
# serialization roundtrip
@@ -266,8 +266,8 @@ def run_serialization_test(self, instance):
266266
revived_cfg = revived_instance.get_config()
267267
revived_cfg_json = json.dumps(revived_cfg, sort_keys=True, indent=4)
268268
self.assertEqual(cfg_json, revived_cfg_json)
269-
# Dir tests only work on keras-core.
270-
if config.multi_backend():
269+
# Dir tests only work with Keras 3.
270+
if config.keras_3():
271271
new_dir = dir(revived_instance)[:]
272272
for lst in [ref_dir, new_dir]:
273273
if "__annotations__" in lst:

keras_nlp/utils/tensor_utils.py

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def is_tensor_type(x):
153153

154154

155155
def standardize_dtype(dtype):
156-
if config.multi_backend():
156+
if config.keras_3():
157157
return keras.backend.standardize_dtype(dtype)
158158
if hasattr(dtype, "name"):
159159
return dtype.name

tools/checkpoint_conversion/convert_t5_checkpoints.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
2020
from absl import app
2121
from absl import flags
2222
from checkpoint_conversion_utils import get_md5_checksum
23-
from keras_core import ops
23+
from keras import ops
2424

2525
import keras_nlp
2626
