@@ -293,7 +293,16 @@ def test_set_xformers_attn_processor_for_determinism(self):
293293 with torch .no_grad ():
294294 output_2 = model (** inputs_dict )[0 ]
295295
296+ model .set_attn_processor (XFormersAttnProcessor ())
297+ assert all (type (proc ) == XFormersAttnProcessor for proc in model .attn_processors .values ())
298+ with torch .no_grad ():
299+ output_3 = model (** inputs_dict )[0 ]
300+
301+ torch .use_deterministic_algorithms (True )
302+
296303 assert torch .allclose (output , output_2 , atol = self .base_precision )
304+ assert torch .allclose (output , output_3 , atol = self .base_precision )
305+ assert torch .allclose (output_2 , output_3 , atol = self .base_precision )
297306
298307 @require_torch_gpu
299308 def test_set_attn_processor_for_determinism (self ):
@@ -315,11 +324,6 @@ def test_set_attn_processor_for_determinism(self):
315324 with torch .no_grad ():
316325 output_2 = model (** inputs_dict )[0 ]
317326
318- model .enable_xformers_memory_efficient_attention ()
319- assert all (type (proc ) == XFormersAttnProcessor for proc in model .attn_processors .values ())
320- with torch .no_grad ():
321- model (** inputs_dict )[0 ]
322-
323327 model .set_attn_processor (AttnProcessor2_0 ())
324328 assert all (type (proc ) == AttnProcessor2_0 for proc in model .attn_processors .values ())
325329 with torch .no_grad ():
@@ -330,18 +334,12 @@ def test_set_attn_processor_for_determinism(self):
330334 with torch .no_grad ():
331335 output_5 = model (** inputs_dict )[0 ]
332336
333- model .set_attn_processor (XFormersAttnProcessor ())
334- assert all (type (proc ) == XFormersAttnProcessor for proc in model .attn_processors .values ())
335- with torch .no_grad ():
336- output_6 = model (** inputs_dict )[0 ]
337-
338337 torch .use_deterministic_algorithms (True )
339338
340339 # make sure that outputs match
341340 assert torch .allclose (output_2 , output_1 , atol = self .base_precision )
342341 assert torch .allclose (output_2 , output_4 , atol = self .base_precision )
343342 assert torch .allclose (output_2 , output_5 , atol = self .base_precision )
344- assert torch .allclose (output_2 , output_6 , atol = self .base_precision )
345343
346344 def test_from_save_pretrained_variant (self , expected_max_diff = 5e-5 ):
347345 init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
0 commit comments