18
18
import pytest
19
19
from absl .testing import parameterized
20
20
21
- from keras_nlp .models import AlbertClassifier
22
- from keras_nlp .models import BertClassifier
23
- from keras_nlp .models import RobertaClassifier
21
+ from keras_nlp .models .albert .albert_classifier import AlbertClassifier
22
+ from keras_nlp .models .backbone import Backbone
23
+ from keras_nlp .models .bert .bert_classifier import BertClassifier
24
+ from keras_nlp .models .roberta .roberta_classifier import RobertaClassifier
25
+ from keras_nlp .models .task import Task
24
26
from keras_nlp .tests .test_case import TestCase
25
- from keras_nlp .utils import preset_utils
27
+ from keras_nlp .utils .preset_utils import check_preset_class
28
+ from keras_nlp .utils .preset_utils import load_from_preset
29
+ from keras_nlp .utils .preset_utils import save_to_preset
26
30
27
31
28
32
class PresetUtilsTest (TestCase ):
@@ -36,7 +40,7 @@ class PresetUtilsTest(TestCase):
36
40
def test_preset_saving (self , cls , preset_name , tokenizer_type ):
37
41
save_dir = self .get_temp_dir ()
38
42
model = cls .from_preset (preset_name , num_classes = 2 )
39
- preset_utils . save_to_preset (model , save_dir )
43
+ save_to_preset (model , save_dir )
40
44
41
45
if tokenizer_type == "bytepair" :
42
46
vocab_filename = "assets/tokenizer/vocabulary.json"
@@ -72,7 +76,14 @@ def test_preset_saving(self, cls, preset_name, tokenizer_type):
72
76
self .assertEqual (config ["weights" ], "model.weights.h5" )
73
77
74
78
# Try loading the model from preset directory
75
- restored_model = preset_utils .load_from_preset (save_dir )
79
+ self .assertEqual (cls , check_preset_class (save_dir , cls ))
80
+ self .assertEqual (cls , check_preset_class (save_dir , Task ))
81
+ with self .assertRaises (ValueError ):
82
+ # Preset is a subclass of Task, not Backbone.
83
+ check_preset_class (save_dir , Backbone )
84
+
85
+ # Try loading the model from preset directory
86
+ restored_model = load_from_preset (save_dir )
76
87
77
88
train_data = (
78
89
["the quick brown fox." , "the slow brown fox." ], # Features.
0 commit comments