Skip to content

Commit 66fd3ec

Browse files
authored
[CI] try to fix GPU OOMs between tests and excessive tqdm logging (huggingface#323)

* Fix tqdm and OOM
* tqdm auto
* tqdm is still spamming; try to disable it altogether
* rather just set the pipe config, to keep the global tqdm clean
* style
1 parent 3a536ac commit 66fd3ec

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

tests/test_pipelines.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import gc
1617
import random
1718
import tempfile
1819
import unittest
@@ -77,6 +78,12 @@ def test_progress_bar(capsys):
7778

7879

7980
class PipelineFastTests(unittest.TestCase):
81+
def tearDown(self):
82+
# clean up the VRAM after each test
83+
super().tearDown()
84+
gc.collect()
85+
torch.cuda.empty_cache()
86+
8087
@property
8188
def dummy_image(self):
8289
batch_size = 1
@@ -186,6 +193,7 @@ def test_ddim(self):
186193

187194
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
188195
ddpm.to(torch_device)
196+
ddpm.set_progress_bar_config(disable=None)
189197

190198
generator = torch.manual_seed(0)
191199
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
@@ -204,6 +212,7 @@ def test_pndm_cifar10(self):
204212

205213
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
206214
pndm.to(torch_device)
215+
pndm.set_progress_bar_config(disable=None)
207216
generator = torch.manual_seed(0)
208217
image = pndm(generator=generator, num_inference_steps=20, output_type="numpy")["sample"]
209218

@@ -222,6 +231,7 @@ def test_ldm_text2img(self):
222231

223232
ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
224233
ldm.to(torch_device)
234+
ldm.set_progress_bar_config(disable=None)
225235

226236
prompt = "A painting of a squirrel eating a burger"
227237
generator = torch.manual_seed(0)
@@ -261,6 +271,7 @@ def test_stable_diffusion_ddim(self):
261271
feature_extractor=self.dummy_extractor,
262272
)
263273
sd_pipe = sd_pipe.to(device)
274+
sd_pipe.set_progress_bar_config(disable=None)
264275

265276
prompt = "A painting of a squirrel eating a burger"
266277
generator = torch.Generator(device=device).manual_seed(0)
@@ -293,6 +304,7 @@ def test_stable_diffusion_pndm(self):
293304
feature_extractor=self.dummy_extractor,
294305
)
295306
sd_pipe = sd_pipe.to(device)
307+
sd_pipe.set_progress_bar_config(disable=None)
296308

297309
prompt = "A painting of a squirrel eating a burger"
298310
generator = torch.Generator(device=device).manual_seed(0)
@@ -325,6 +337,7 @@ def test_stable_diffusion_k_lms(self):
325337
feature_extractor=self.dummy_extractor,
326338
)
327339
sd_pipe = sd_pipe.to(device)
340+
sd_pipe.set_progress_bar_config(disable=None)
328341

329342
prompt = "A painting of a squirrel eating a burger"
330343
generator = torch.Generator(device=device).manual_seed(0)
@@ -344,6 +357,7 @@ def test_score_sde_ve_pipeline(self):
344357

345358
sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
346359
sde_ve.to(torch_device)
360+
sde_ve.set_progress_bar_config(disable=None)
347361

348362
torch.manual_seed(0)
349363
image = sde_ve(num_inference_steps=2, output_type="numpy")["sample"]
@@ -362,6 +376,7 @@ def test_ldm_uncond(self):
362376

363377
ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
364378
ldm.to(torch_device)
379+
ldm.set_progress_bar_config(disable=None)
365380

366381
generator = torch.manual_seed(0)
367382
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
@@ -378,6 +393,7 @@ def test_karras_ve_pipeline(self):
378393

379394
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
380395
pipe.to(torch_device)
396+
pipe.set_progress_bar_config(disable=None)
381397

382398
generator = torch.manual_seed(0)
383399
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy")["sample"]
@@ -408,6 +424,7 @@ def test_stable_diffusion_img2img(self):
408424
feature_extractor=self.dummy_extractor,
409425
)
410426
sd_pipe = sd_pipe.to(device)
427+
sd_pipe.set_progress_bar_config(disable=None)
411428

412429
prompt = "A painting of a squirrel eating a burger"
413430
generator = torch.Generator(device=device).manual_seed(0)
@@ -451,6 +468,7 @@ def test_stable_diffusion_inpaint(self):
451468
feature_extractor=self.dummy_extractor,
452469
)
453470
sd_pipe = sd_pipe.to(device)
471+
sd_pipe.set_progress_bar_config(disable=None)
454472

455473
prompt = "A painting of a squirrel eating a burger"
456474
generator = torch.Generator(device=device).manual_seed(0)
@@ -474,6 +492,12 @@ def test_stable_diffusion_inpaint(self):
474492

475493

476494
class PipelineTesterMixin(unittest.TestCase):
495+
def tearDown(self):
496+
# clean up the VRAM after each test
497+
super().tearDown()
498+
gc.collect()
499+
torch.cuda.empty_cache()
500+
477501
def test_from_pretrained_save_pretrained(self):
478502
# 1. Load models
479503
model = UNet2DModel(
@@ -489,6 +513,7 @@ def test_from_pretrained_save_pretrained(self):
489513

490514
ddpm = DDPMPipeline(model, schedular)
491515
ddpm.to(torch_device)
516+
ddpm.set_progress_bar_config(disable=None)
492517

493518
with tempfile.TemporaryDirectory() as tmpdirname:
494519
ddpm.save_pretrained(tmpdirname)
@@ -511,8 +536,10 @@ def test_from_pretrained_hub(self):
511536

512537
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
513538
ddpm.to(torch_device)
539+
ddpm.set_progress_bar_config(disable=None)
514540
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
515541
ddpm_from_hub.to(torch_device)
542+
ddpm_from_hub.set_progress_bar_config(disable=None)
516543

517544
generator = torch.manual_seed(0)
518545

@@ -532,9 +559,11 @@ def test_from_pretrained_hub_pass_model(self):
532559
unet = UNet2DModel.from_pretrained(model_path)
533560
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
534561
ddpm_from_hub_custom_model.to(torch_device)
562+
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
535563

536564
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
537565
ddpm_from_hub.to(torch_device)
566+
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
538567

539568
generator = torch.manual_seed(0)
540569

@@ -550,6 +579,7 @@ def test_output_format(self):
550579

551580
pipe = DDIMPipeline.from_pretrained(model_path)
552581
pipe.to(torch_device)
582+
pipe.set_progress_bar_config(disable=None)
553583

554584
generator = torch.manual_seed(0)
555585
images = pipe(generator=generator, output_type="numpy")["sample"]
@@ -576,6 +606,7 @@ def test_ddpm_cifar10(self):
576606

577607
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
578608
ddpm.to(torch_device)
609+
ddpm.set_progress_bar_config(disable=None)
579610

580611
generator = torch.manual_seed(0)
581612
image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -595,6 +626,7 @@ def test_ddim_lsun(self):
595626

596627
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
597628
ddpm.to(torch_device)
629+
ddpm.set_progress_bar_config(disable=None)
598630

599631
generator = torch.manual_seed(0)
600632
image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -614,6 +646,7 @@ def test_ddim_cifar10(self):
614646

615647
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
616648
ddim.to(torch_device)
649+
ddim.set_progress_bar_config(disable=None)
617650

618651
generator = torch.manual_seed(0)
619652
image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
@@ -633,6 +666,7 @@ def test_pndm_cifar10(self):
633666

634667
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
635668
pndm.to(torch_device)
669+
pndm.set_progress_bar_config(disable=None)
636670
generator = torch.manual_seed(0)
637671
image = pndm(generator=generator, output_type="numpy")["sample"]
638672

@@ -646,6 +680,7 @@ def test_pndm_cifar10(self):
646680
def test_ldm_text2img(self):
647681
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
648682
ldm.to(torch_device)
683+
ldm.set_progress_bar_config(disable=None)
649684

650685
prompt = "A painting of a squirrel eating a burger"
651686
generator = torch.manual_seed(0)
@@ -663,6 +698,7 @@ def test_ldm_text2img(self):
663698
def test_ldm_text2img_fast(self):
664699
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
665700
ldm.to(torch_device)
701+
ldm.set_progress_bar_config(disable=None)
666702

667703
prompt = "A painting of a squirrel eating a burger"
668704
generator = torch.manual_seed(0)
@@ -680,6 +716,7 @@ def test_stable_diffusion(self):
680716
# make sure here that pndm scheduler skips prk
681717
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
682718
sd_pipe = sd_pipe.to(torch_device)
719+
sd_pipe.set_progress_bar_config(disable=None)
683720

684721
prompt = "A painting of a squirrel eating a burger"
685722
generator = torch.Generator(device=torch_device).manual_seed(0)
@@ -701,6 +738,7 @@ def test_stable_diffusion(self):
701738
def test_stable_diffusion_fast_ddim(self):
702739
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
703740
sd_pipe = sd_pipe.to(torch_device)
741+
sd_pipe.set_progress_bar_config(disable=None)
704742

705743
scheduler = DDIMScheduler(
706744
beta_start=0.00085,
@@ -733,6 +771,7 @@ def test_score_sde_ve_pipeline(self):
733771

734772
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
735773
sde_ve.to(torch_device)
774+
sde_ve.set_progress_bar_config(disable=None)
736775

737776
torch.manual_seed(0)
738777
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
@@ -748,6 +787,7 @@ def test_score_sde_ve_pipeline(self):
748787
def test_ldm_uncond(self):
749788
ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
750789
ldm.to(torch_device)
790+
ldm.set_progress_bar_config(disable=None)
751791

752792
generator = torch.manual_seed(0)
753793
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
@@ -768,8 +808,10 @@ def test_ddpm_ddim_equality(self):
768808

769809
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
770810
ddpm.to(torch_device)
811+
ddpm.set_progress_bar_config(disable=None)
771812
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
772813
ddim.to(torch_device)
814+
ddim.set_progress_bar_config(disable=None)
773815

774816
generator = torch.manual_seed(0)
775817
ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -790,9 +832,11 @@ def test_ddpm_ddim_equality_batched(self):
790832

791833
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
792834
ddpm.to(torch_device)
835+
ddpm.set_progress_bar_config(disable=None)
793836

794837
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
795838
ddim.to(torch_device)
839+
ddim.set_progress_bar_config(disable=None)
796840

797841
generator = torch.manual_seed(0)
798842
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
@@ -813,6 +857,7 @@ def test_karras_ve_pipeline(self):
813857

814858
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
815859
pipe.to(torch_device)
860+
pipe.set_progress_bar_config(disable=None)
816861

817862
generator = torch.manual_seed(0)
818863
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
@@ -827,6 +872,7 @@ def test_karras_ve_pipeline(self):
827872
def test_lms_stable_diffusion_pipeline(self):
828873
model_id = "CompVis/stable-diffusion-v1-1"
829874
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(torch_device)
875+
pipe.set_progress_bar_config(disable=None)
830876
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
831877
pipe.scheduler = scheduler
832878

@@ -852,6 +898,7 @@ def test_stable_diffusion_img2img_pipeline(self):
852898
model_id = "CompVis/stable-diffusion-v1-4"
853899
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, use_auth_token=True)
854900
pipe.to(torch_device)
901+
pipe.set_progress_bar_config(disable=None)
855902

856903
prompt = "A fantasy landscape, trending on artstation"
857904

@@ -878,6 +925,7 @@ def test_stable_diffusion_in_paint_pipeline(self):
878925
model_id = "CompVis/stable-diffusion-v1-4"
879926
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, use_auth_token=True)
880927
pipe.to(torch_device)
928+
pipe.set_progress_bar_config(disable=None)
881929

882930
prompt = "A red cat sitting on a parking bench"
883931

0 commit comments

Comments (0)