@@ -195,7 +195,7 @@ class ModelTesterMixin:
195195 main_input_name = None # overwrite in model specific tester class
196196 base_precision = 1e-3
197197
198- def test_from_save_pretrained (self ):
198+ def test_from_save_pretrained (self , expected_max_diff = 5e-5 ):
199199 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
200200
201201 model = self .model_class (** init_dict )
@@ -221,8 +221,8 @@ def test_from_save_pretrained(self):
221221 if isinstance (new_image , dict ):
222222 new_image = new_image .to_tuple ()[0 ]
223223
224- max_diff = (image - new_image ).abs ().sum ().item ()
225- self .assertLessEqual (max_diff , 5e-5 , "Models give different forward passes" )
224+ max_diff = (image - new_image ).abs ().max ().item ()
225+ self .assertLessEqual (max_diff , expected_max_diff , "Models give different forward passes" )
226226
227227 def test_getattr_is_correct (self ):
228228 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
@@ -316,7 +316,7 @@ def test_set_attn_processor_for_determinism(self):
316316 assert torch .allclose (output_2 , output_5 , atol = self .base_precision )
317317 assert torch .allclose (output_2 , output_6 , atol = self .base_precision )
318318
319- def test_from_save_pretrained_variant (self ):
319+ def test_from_save_pretrained_variant (self , expected_max_diff = 5e-5 ):
320320 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
321321
322322 model = self .model_class (** init_dict )
@@ -351,8 +351,8 @@ def test_from_save_pretrained_variant(self):
351351 if isinstance (new_image , dict ):
352352 new_image = new_image .to_tuple ()[0 ]
353353
354- max_diff = (image - new_image ).abs ().sum ().item ()
355- self .assertLessEqual (max_diff , 5e-5 , "Models give different forward passes" )
354+ max_diff = (image - new_image ).abs ().max ().item ()
355+ self .assertLessEqual (max_diff , expected_max_diff , "Models give different forward passes" )
356356
357357 @require_torch_2
358358 def test_from_save_pretrained_dynamo (self ):
0 commit comments