Skip to content

Commit d2d9764

Browse files
[Tests] Speed up slow tests (huggingface#1040)
* [Tests] Speed up slow tests * Up * up
1 parent a80480f commit d2d9764

File tree

12 files changed

+58
-36
lines changed

12 files changed

+58
-36
lines changed

tests/pipelines/dance_diffusion/test_dance_diffusion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def tearDown(self):
8686
def test_dance_diffusion(self):
8787
device = torch_device
8888

89-
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
89+
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto")
9090
pipe = pipe.to(device)
9191
pipe.set_progress_bar_config(disable=None)
9292

@@ -103,7 +103,9 @@ def test_dance_diffusion(self):
103103
def test_dance_diffusion_fp16(self):
104104
device = torch_device
105105

106-
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
106+
pipe = DanceDiffusionPipeline.from_pretrained(
107+
"harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
108+
)
107109
pipe = pipe.to(device)
108110
pipe.set_progress_bar_config(disable=None)
109111

tests/pipelines/ddim/test_ddim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
7878
def test_inference_ema_bedroom(self):
7979
model_id = "google/ddpm-ema-bedroom-256"
8080

81-
unet = UNet2DModel.from_pretrained(model_id)
81+
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
8282
scheduler = DDIMScheduler.from_config(model_id)
8383

8484
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -97,7 +97,7 @@ def test_inference_ema_bedroom(self):
9797
def test_inference_cifar10(self):
9898
model_id = "google/ddpm-cifar10-32"
9999

100-
unet = UNet2DModel.from_pretrained(model_id)
100+
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
101101
scheduler = DDIMScheduler()
102102

103103
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

tests/pipelines/ddpm/test_ddpm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
3838
def test_inference_cifar10(self):
3939
model_id = "google/ddpm-cifar10-32"
4040

41-
unet = UNet2DModel.from_pretrained(model_id)
41+
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
4242
scheduler = DDPMScheduler.from_config(model_id)
4343

4444
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

tests/pipelines/karras_ve/test_karras_ve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_inference(self):
7070
class KarrasVePipelineIntegrationTests(unittest.TestCase):
7171
def test_inference(self):
7272
model_id = "google/ncsnpp-celebahq-256"
73-
model = UNet2DModel.from_pretrained(model_id)
73+
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
7474
scheduler = KarrasVeScheduler()
7575

7676
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

tests/pipelines/latent_diffusion/test_latent_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_inference_text2img(self):
121121
@require_torch
122122
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
123123
def test_inference_text2img(self):
124-
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
124+
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
125125
ldm.to(torch_device)
126126
ldm.set_progress_bar_config(disable=None)
127127

@@ -138,7 +138,7 @@ def test_inference_text2img(self):
138138
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
139139

140140
def test_inference_text2img_fast(self):
141-
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
141+
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
142142
ldm.to(torch_device)
143143
ldm.set_progress_bar_config(disable=None)
144144

tests/pipelines/pndm/test_pndm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
7171
def test_inference_cifar10(self):
7272
model_id = "google/ddpm-cifar10-32"
7373

74-
unet = UNet2DModel.from_pretrained(model_id)
74+
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
7575
scheduler = PNDMScheduler()
7676

7777
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)

tests/pipelines/score_sde_ve/test_score_sde_ve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_inference(self):
7272
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
7373
def test_inference(self):
7474
model_id = "google/ncsnpp-church-256"
75-
model = UNet2DModel.from_pretrained(model_id)
75+
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
7676

7777
scheduler = ScoreSdeVeScheduler.from_config(model_id)
7878

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def tearDown(self):
528528

529529
def test_stable_diffusion(self):
530530
# make sure here that pndm scheduler skips prk
531-
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
531+
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
532532
sd_pipe = sd_pipe.to(torch_device)
533533
sd_pipe.set_progress_bar_config(disable=None)
534534

@@ -548,7 +548,7 @@ def test_stable_diffusion(self):
548548
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
549549

550550
def test_stable_diffusion_fast_ddim(self):
551-
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
551+
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
552552
sd_pipe = sd_pipe.to(torch_device)
553553
sd_pipe.set_progress_bar_config(disable=None)
554554

@@ -576,7 +576,7 @@ def test_stable_diffusion_fast_ddim(self):
576576

577577
def test_lms_stable_diffusion_pipeline(self):
578578
model_id = "CompVis/stable-diffusion-v1-1"
579-
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
579+
pipe = StableDiffusionPipeline.from_pretrained(model_id, device_map="auto").to(torch_device)
580580
pipe.set_progress_bar_config(disable=None)
581581
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
582582
pipe.scheduler = scheduler
@@ -595,9 +595,10 @@ def test_lms_stable_diffusion_pipeline(self):
595595
def test_stable_diffusion_memory_chunking(self):
596596
torch.cuda.reset_peak_memory_stats()
597597
model_id = "CompVis/stable-diffusion-v1-4"
598-
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to(
599-
torch_device
598+
pipe = StableDiffusionPipeline.from_pretrained(
599+
model_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
600600
)
601+
pipe.to(torch_device)
601602
pipe.set_progress_bar_config(disable=None)
602603

603604
prompt = "a photograph of an astronaut riding a horse"
@@ -633,9 +634,10 @@ def test_stable_diffusion_memory_chunking(self):
633634
def test_stable_diffusion_text2img_pipeline_fp16(self):
634635
torch.cuda.reset_peak_memory_stats()
635636
model_id = "CompVis/stable-diffusion-v1-4"
636-
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to(
637-
torch_device
637+
pipe = StableDiffusionPipeline.from_pretrained(
638+
model_id, revision="fp16", device_map="auto", torch_dtype=torch.float16
638639
)
640+
pipe = pipe.to(torch_device)
639641
pipe.set_progress_bar_config(disable=None)
640642

641643
prompt = "a photograph of an astronaut riding a horse"
@@ -670,6 +672,7 @@ def test_stable_diffusion_text2img_pipeline(self):
670672
pipe = StableDiffusionPipeline.from_pretrained(
671673
model_id,
672674
safety_checker=None,
675+
device_map="auto",
673676
)
674677
pipe.to(torch_device)
675678
pipe.set_progress_bar_config(disable=None)
@@ -711,7 +714,7 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
711714
test_callback_fn.has_been_called = False
712715

713716
pipe = StableDiffusionPipeline.from_pretrained(
714-
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
717+
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
715718
)
716719
pipe = pipe.to(torch_device)
717720
pipe.set_progress_bar_config(disable=None)
@@ -737,7 +740,7 @@ def test_stable_diffusion_accelerate_auto_device(self):
737740

738741
start_time = time.time()
739742
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
740-
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
743+
pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
741744
)
742745
pipeline_normal_load.to(torch_device)
743746
normal_load_time = time.time() - start_time
@@ -758,7 +761,9 @@ def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
758761
pipeline_id = "CompVis/stable-diffusion-v1-4"
759762
prompt = "Andromeda galaxy in a bottle"
760763

761-
pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
764+
pipeline = StableDiffusionPipeline.from_pretrained(
765+
pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
766+
)
762767
pipeline.enable_attention_slicing(1)
763768
pipeline.enable_sequential_cpu_offload()
764769

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ def test_stable_diffusion_img2img_pipeline(self):
488488
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
489489
model_id,
490490
safety_checker=None,
491+
device_map="auto",
491492
)
492493
pipe.to(torch_device)
493494
pipe.set_progress_bar_config(disable=None)
@@ -529,6 +530,7 @@ def test_stable_diffusion_img2img_pipeline_k_lms(self):
529530
model_id,
530531
scheduler=lms,
531532
safety_checker=None,
533+
device_map="auto",
532534
)
533535
pipe.to(torch_device)
534536
pipe.set_progress_bar_config(disable=None)
@@ -580,7 +582,7 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
580582
init_image = init_image.resize((768, 512))
581583

582584
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
583-
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
585+
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
584586
)
585587
pipe.to(torch_device)
586588
pipe.set_progress_bar_config(disable=None)

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def test_stable_diffusion_inpaint_pipeline(self):
288288
pipe = StableDiffusionInpaintPipeline.from_pretrained(
289289
model_id,
290290
safety_checker=None,
291+
device_map="auto",
291292
)
292293
pipe.to(torch_device)
293294
pipe.set_progress_bar_config(disable=None)
@@ -329,6 +330,7 @@ def test_stable_diffusion_inpaint_pipeline_fp16(self):
329330
revision="fp16",
330331
torch_dtype=torch.float16,
331332
safety_checker=None,
333+
device_map="auto",
332334
)
333335
pipe.to(torch_device)
334336
pipe.set_progress_bar_config(disable=None)
@@ -366,7 +368,9 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self):
366368

367369
pndm = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True)
368370
model_id = "runwayml/stable-diffusion-inpainting"
369-
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
371+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
372+
model_id, safety_checker=None, scheduler=pndm, device_map="auto"
373+
)
370374
pipe.to(torch_device)
371375
pipe.set_progress_bar_config(disable=None)
372376
pipe.enable_attention_slicing()

0 commit comments

Comments
 (0)