@@ -182,6 +182,36 @@ def test_load_pretrained(
182
182
assert loaded_model .config == mock_config
183
183
184
184
185
+ def test_load_pretrained_dim (
186
+ tmp_path : Path , mock_vectors : np .ndarray , mock_tokenizer : Tokenizer , mock_config : dict [str , str ]
187
+ ) -> None :
188
+ """Test loading a pretrained model after saving it."""
189
+ # Save the model to a temporary path
190
+ model = StaticModel (vectors = mock_vectors , tokenizer = mock_tokenizer , config = mock_config )
191
+ save_path = tmp_path / "saved_model"
192
+ model .save_pretrained (save_path )
193
+
194
+ # Load the model back from the same path
195
+ loaded_model = StaticModel .from_pretrained (save_path , dimensionality = 2 )
196
+
197
+ # Assert that the loaded model has the same properties as the original one
198
+ np .testing .assert_array_equal (loaded_model .embedding , mock_vectors [:, :2 ])
199
+ assert loaded_model .tokenizer .get_vocab () == mock_tokenizer .get_vocab ()
200
+ assert loaded_model .config == mock_config
201
+
202
+ # Load the model back from the same path
203
+ loaded_model = StaticModel .from_pretrained (save_path , dimensionality = None )
204
+
205
+ # Assert that the loaded model has the same properties as the original one
206
+ np .testing .assert_array_equal (loaded_model .embedding , mock_vectors )
207
+ assert loaded_model .tokenizer .get_vocab () == mock_tokenizer .get_vocab ()
208
+ assert loaded_model .config == mock_config
209
+
210
+ # Load the model back from the same path
211
+ with pytest .raises (ValueError ):
212
+ StaticModel .from_pretrained (save_path , dimensionality = 3000 )
213
+
214
+
185
215
def test_initialize_normalize (mock_vectors : np .ndarray , mock_tokenizer : Tokenizer ) -> None :
186
216
"""Tests whether the normalization initialization is correct."""
187
217
model = StaticModel (mock_vectors , mock_tokenizer , {}, normalize = None )
0 commit comments