Skip to content

Commit 71f56c7

Browse files
authored
Model tests xformers fixes (huggingface#5679)
* fix model xformers test * update
1 parent 6a89a6c commit 71f56c7

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

tests/models/test_modeling_common.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,16 @@ def test_set_xformers_attn_processor_for_determinism(self):
293293
with torch.no_grad():
294294
output_2 = model(**inputs_dict)[0]
295295

296+
model.set_attn_processor(XFormersAttnProcessor())
297+
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
298+
with torch.no_grad():
299+
output_3 = model(**inputs_dict)[0]
300+
301+
torch.use_deterministic_algorithms(True)
302+
296303
assert torch.allclose(output, output_2, atol=self.base_precision)
304+
assert torch.allclose(output, output_3, atol=self.base_precision)
305+
assert torch.allclose(output_2, output_3, atol=self.base_precision)
297306

298307
@require_torch_gpu
299308
def test_set_attn_processor_for_determinism(self):
@@ -315,11 +324,6 @@ def test_set_attn_processor_for_determinism(self):
315324
with torch.no_grad():
316325
output_2 = model(**inputs_dict)[0]
317326

318-
model.enable_xformers_memory_efficient_attention()
319-
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
320-
with torch.no_grad():
321-
model(**inputs_dict)[0]
322-
323327
model.set_attn_processor(AttnProcessor2_0())
324328
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
325329
with torch.no_grad():
@@ -330,18 +334,12 @@ def test_set_attn_processor_for_determinism(self):
330334
with torch.no_grad():
331335
output_5 = model(**inputs_dict)[0]
332336

333-
model.set_attn_processor(XFormersAttnProcessor())
334-
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
335-
with torch.no_grad():
336-
output_6 = model(**inputs_dict)[0]
337-
338337
torch.use_deterministic_algorithms(True)
339338

340339
# make sure that outputs match
341340
assert torch.allclose(output_2, output_1, atol=self.base_precision)
342341
assert torch.allclose(output_2, output_4, atol=self.base_precision)
343342
assert torch.allclose(output_2, output_5, atol=self.base_precision)
344-
assert torch.allclose(output_2, output_6, atol=self.base_precision)
345343

346344
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
347345
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

0 commit comments

Comments
 (0)