Skip to content

Commit 9cc8110

Browse files
committed
Use subclass checking check_preset_class (keras-team#1344)
Not currently needed for anything, just to keep in sync with KerasCV.
1 parent 7cc4323 commit 9cc8110

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

keras_nlp/utils/preset_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,9 @@ def check_preset_class(
203203
cls = keras.saving.get_registered_object(config["registered_name"])
204204
if not isinstance(classes, (tuple, list)):
205205
classes = (classes,)
206-
if cls not in classes:
206+
# Allow subclasses for testing a base class, e.g.
207+
# `check_preset_class(preset, Backbone)`
208+
if not any(issubclass(cls, x) for x in classes):
207209
raise ValueError(
208210
f"Unexpected class in preset `'{preset}'`. "
209211
"When calling `from_preset()` on a class object, the preset class "

keras_nlp/utils/preset_utils_test.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@
1818
import pytest
1919
from absl.testing import parameterized
2020

21-
from keras_nlp.models import AlbertClassifier
22-
from keras_nlp.models import BertClassifier
23-
from keras_nlp.models import RobertaClassifier
21+
from keras_nlp.models.albert.albert_classifier import AlbertClassifier
22+
from keras_nlp.models.backbone import Backbone
23+
from keras_nlp.models.bert.bert_classifier import BertClassifier
24+
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
25+
from keras_nlp.models.task import Task
2426
from keras_nlp.tests.test_case import TestCase
25-
from keras_nlp.utils import preset_utils
27+
from keras_nlp.utils.preset_utils import check_preset_class
28+
from keras_nlp.utils.preset_utils import load_from_preset
29+
from keras_nlp.utils.preset_utils import save_to_preset
2630

2731

2832
class PresetUtilsTest(TestCase):
@@ -36,7 +40,7 @@ class PresetUtilsTest(TestCase):
3640
def test_preset_saving(self, cls, preset_name, tokenizer_type):
3741
save_dir = self.get_temp_dir()
3842
model = cls.from_preset(preset_name, num_classes=2)
39-
preset_utils.save_to_preset(model, save_dir)
43+
save_to_preset(model, save_dir)
4044

4145
if tokenizer_type == "bytepair":
4246
vocab_filename = "assets/tokenizer/vocabulary.json"
@@ -72,7 +76,14 @@ def test_preset_saving(self, cls, preset_name, tokenizer_type):
7276
self.assertEqual(config["weights"], "model.weights.h5")
7377

7478
# Try loading the model from preset directory
75-
restored_model = preset_utils.load_from_preset(save_dir)
79+
self.assertEqual(cls, check_preset_class(save_dir, cls))
80+
self.assertEqual(cls, check_preset_class(save_dir, Task))
81+
with self.assertRaises(ValueError):
82+
# Preset is a subclass of Task, not Backbone.
83+
check_preset_class(save_dir, Backbone)
84+
85+
# Try loading the model from preset directory
86+
restored_model = load_from_preset(save_dir)
7687

7788
train_data = (
7889
["the quick brown fox.", "the slow brown fox."], # Features.

0 commit comments

Comments
 (0)