Skip to content

Commit 2997075

Browse files
Fast Tests on PR improvements: Batch Tests fixes (huggingface#5080)
* fix test * initial commit * change test * updates: * fix tests * test fix * test fix * fix tests * make test faster * clean up * fix precision in test * fix precision * Fix tests * Fix logging test * fix test * fix test --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent c2787c1 commit 2997075

20 files changed

+117
-208
lines changed

tests/pipelines/audioldm/test_audioldm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def test_attention_slicing_forward_pass(self):
359359
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
360360

361361
def test_inference_batch_single_identical(self):
362-
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
362+
self._test_inference_batch_single_identical()
363363

364364
@unittest.skipIf(
365365
torch_device != "cuda" or not is_xformers_available(),

tests/pipelines/audioldm2/test_audioldm2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def test_dict_tuple_outputs_equivalent(self):
459459

460460
def test_inference_batch_single_identical(self):
461461
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
462-
self._test_inference_batch_single_identical(test_mean_pixel_difference=False, expected_max_diff=2e-4)
462+
self._test_inference_batch_single_identical(expected_max_diff=2e-4)
463463

464464
def test_save_load_local(self):
465465
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model

tests/pipelines/dit/test_dit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_inference(self):
9696
self.assertLessEqual(max_diff, 1e-3)
9797

9898
def test_inference_batch_single_identical(self):
99-
self._test_inference_batch_single_identical(relax_max_difference=True, expected_max_diff=1e-3)
99+
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
100100

101101
@unittest.skipIf(
102102
torch_device != "cuda" or not is_xformers_available(),

tests/pipelines/kandinsky/test_kandinsky_prior.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,15 +224,7 @@ def test_kandinsky_prior(self):
224224

225225
@skip_mps
226226
def test_inference_batch_single_identical(self):
227-
test_max_difference = torch_device == "cpu"
228-
relax_max_difference = True
229-
test_mean_pixel_difference = False
230-
231-
self._test_inference_batch_single_identical(
232-
test_max_difference=test_max_difference,
233-
relax_max_difference=relax_max_difference,
234-
test_mean_pixel_difference=test_mean_pixel_difference,
235-
)
227+
self._test_inference_batch_single_identical(expected_max_diff=1e-2)
236228

237229
@skip_mps
238230
def test_attention_slicing_forward_pass(self):

tests/pipelines/kandinsky_v22/test_kandinsky_prior.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,15 +224,7 @@ def test_kandinsky_prior(self):
224224

225225
@skip_mps
226226
def test_inference_batch_single_identical(self):
227-
test_max_difference = torch_device == "cpu"
228-
relax_max_difference = True
229-
test_mean_pixel_difference = False
230-
231-
self._test_inference_batch_single_identical(
232-
test_max_difference=test_max_difference,
233-
relax_max_difference=relax_max_difference,
234-
test_mean_pixel_difference=test_mean_pixel_difference,
235-
)
227+
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
236228

237229
@skip_mps
238230
def test_attention_slicing_forward_pass(self):

tests/pipelines/kandinsky_v22/test_kandinsky_prior_emb2emb.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,7 @@ def test_kandinsky_prior_emb2emb(self):
234234

235235
@skip_mps
236236
def test_inference_batch_single_identical(self):
237-
test_max_difference = torch_device == "cpu"
238-
relax_max_difference = True
239-
test_mean_pixel_difference = False
240-
241-
self._test_inference_batch_single_identical(
242-
test_max_difference=test_max_difference,
243-
relax_max_difference=relax_max_difference,
244-
test_mean_pixel_difference=test_mean_pixel_difference,
245-
)
237+
self._test_inference_batch_single_identical(expected_max_diff=1e-2)
246238

247239
@skip_mps
248240
def test_attention_slicing_forward_pass(self):

tests/pipelines/musicldm/test_musicldm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def test_attention_slicing_forward_pass(self):
373373
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
374374

375375
def test_inference_batch_single_identical(self):
376-
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
376+
self._test_inference_batch_single_identical()
377377

378378
@unittest.skipIf(
379379
torch_device != "cuda" or not is_xformers_available(),

tests/pipelines/shap_e/test_shap_e.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4444

4545
@property
4646
def text_embedder_hidden_size(self):
47-
return 32
47+
return 16
4848

4949
@property
5050
def time_input_dim(self):
51-
return 32
51+
return 16
5252

5353
@property
5454
def time_embed_dim(self):
@@ -201,14 +201,7 @@ def test_inference_batch_consistent(self):
201201
self._test_inference_batch_consistent(batch_sizes=[1, 2])
202202

203203
def test_inference_batch_single_identical(self):
204-
test_max_difference = torch_device == "cpu"
205-
relax_max_difference = True
206-
207-
self._test_inference_batch_single_identical(
208-
batch_size=2,
209-
test_max_difference=test_max_difference,
210-
relax_max_difference=relax_max_difference,
211-
)
204+
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=6e-3)
212205

213206
def test_num_images_per_prompt(self):
214207
components = self.get_dummy_components()

tests/pipelines/shap_e/test_shap_e_img2img.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
5252

5353
@property
5454
def text_embedder_hidden_size(self):
55-
return 32
55+
return 16
5656

5757
@property
5858
def time_input_dim(self):
59-
return 32
59+
return 16
6060

6161
@property
6262
def time_embed_dim(self):
@@ -71,10 +71,10 @@ def dummy_image_encoder(self):
7171
torch.manual_seed(0)
7272
config = CLIPVisionConfig(
7373
hidden_size=self.text_embedder_hidden_size,
74-
image_size=64,
74+
image_size=32,
7575
projection_dim=self.text_embedder_hidden_size,
76-
intermediate_size=37,
77-
num_attention_heads=4,
76+
intermediate_size=24,
77+
num_attention_heads=2,
7878
num_channels=3,
7979
num_hidden_layers=5,
8080
patch_size=1,
@@ -170,7 +170,7 @@ def get_dummy_components(self):
170170
return components
171171

172172
def get_dummy_inputs(self, device, seed=0):
173-
input_image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
173+
input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
174174

175175
if str(device).startswith("mps"):
176176
generator = torch.manual_seed(seed)
@@ -219,15 +219,12 @@ def test_shap_e(self):
219219

220220
def test_inference_batch_consistent(self):
221221
# NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
222-
self._test_inference_batch_consistent(batch_sizes=[1, 2])
222+
self._test_inference_batch_consistent(batch_sizes=[2])
223223

224224
def test_inference_batch_single_identical(self):
225-
test_max_difference = torch_device == "cpu"
226-
relax_max_difference = True
227225
self._test_inference_batch_single_identical(
228226
batch_size=2,
229-
test_max_difference=test_max_difference,
230-
relax_max_difference=relax_max_difference,
227+
expected_max_diff=5e-3,
231228
)
232229

233230
def test_num_images_per_prompt(self):

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -499,14 +499,7 @@ def test_stable_diffusion_long_prompt(self):
499499
negative_prompt = None
500500
num_images_per_prompt = 1
501501
logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
502-
503-
prompt = 25 * "@"
504-
with CaptureLogger(logger) as cap_logger_3:
505-
negative_text_embeddings_3, text_embeddings_3 = sd_pipe.encode_prompt(
506-
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
507-
)
508-
if negative_text_embeddings_3 is not None:
509-
text_embeddings_3 = torch.cat([negative_text_embeddings_3, text_embeddings_3])
502+
logger.setLevel(logging.WARNING)
510503

511504
prompt = 100 * "@"
512505
with CaptureLogger(logger) as cap_logger:
@@ -516,6 +509,9 @@ def test_stable_diffusion_long_prompt(self):
516509
if negative_text_embeddings is not None:
517510
text_embeddings = torch.cat([negative_text_embeddings, text_embeddings])
518511

512+
# 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
513+
assert cap_logger.out.count("@") == 25
514+
519515
negative_prompt = "Hello"
520516
with CaptureLogger(logger) as cap_logger_2:
521517
negative_text_embeddings_2, text_embeddings_2 = sd_pipe.encode_prompt(
@@ -524,12 +520,18 @@ def test_stable_diffusion_long_prompt(self):
524520
if negative_text_embeddings_2 is not None:
525521
text_embeddings_2 = torch.cat([negative_text_embeddings_2, text_embeddings_2])
526522

523+
assert cap_logger.out == cap_logger_2.out
524+
525+
prompt = 25 * "@"
526+
with CaptureLogger(logger) as cap_logger_3:
527+
negative_text_embeddings_3, text_embeddings_3 = sd_pipe.encode_prompt(
528+
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
529+
)
530+
if negative_text_embeddings_3 is not None:
531+
text_embeddings_3 = torch.cat([negative_text_embeddings_3, text_embeddings_3])
532+
527533
assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
528534
assert text_embeddings.shape[1] == 77
529-
530-
assert cap_logger.out == cap_logger_2.out
531-
# 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
532-
assert cap_logger.out.count("@") == 25
533535
assert cap_logger_3.out == ""
534536

535537
def test_stable_diffusion_height_width_opt(self):

0 commit comments

Comments
 (0)