1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import gc
1617import random
1718import tempfile
1819import unittest
@@ -77,6 +78,12 @@ def test_progress_bar(capsys):
7778
7879
7980class PipelineFastTests (unittest .TestCase ):
81+ def tearDown (self ):
82+ # clean up the VRAM after each test
83+ super ().tearDown ()
84+ gc .collect ()
85+ torch .cuda .empty_cache ()
86+
8087 @property
8188 def dummy_image (self ):
8289 batch_size = 1
@@ -186,6 +193,7 @@ def test_ddim(self):
186193
187194 ddpm = DDIMPipeline (unet = unet , scheduler = scheduler )
188195 ddpm .to (torch_device )
196+ ddpm .set_progress_bar_config (disable = None )
189197
190198 generator = torch .manual_seed (0 )
191199 image = ddpm (generator = generator , num_inference_steps = 2 , output_type = "numpy" )["sample" ]
@@ -204,6 +212,7 @@ def test_pndm_cifar10(self):
204212
205213 pndm = PNDMPipeline (unet = unet , scheduler = scheduler )
206214 pndm .to (torch_device )
215+ pndm .set_progress_bar_config (disable = None )
207216 generator = torch .manual_seed (0 )
208217 image = pndm (generator = generator , num_inference_steps = 20 , output_type = "numpy" )["sample" ]
209218
@@ -222,6 +231,7 @@ def test_ldm_text2img(self):
222231
223232 ldm = LDMTextToImagePipeline (vqvae = vae , bert = bert , tokenizer = tokenizer , unet = unet , scheduler = scheduler )
224233 ldm .to (torch_device )
234+ ldm .set_progress_bar_config (disable = None )
225235
226236 prompt = "A painting of a squirrel eating a burger"
227237 generator = torch .manual_seed (0 )
@@ -261,6 +271,7 @@ def test_stable_diffusion_ddim(self):
261271 feature_extractor = self .dummy_extractor ,
262272 )
263273 sd_pipe = sd_pipe .to (device )
274+ sd_pipe .set_progress_bar_config (disable = None )
264275
265276 prompt = "A painting of a squirrel eating a burger"
266277 generator = torch .Generator (device = device ).manual_seed (0 )
@@ -293,6 +304,7 @@ def test_stable_diffusion_pndm(self):
293304 feature_extractor = self .dummy_extractor ,
294305 )
295306 sd_pipe = sd_pipe .to (device )
307+ sd_pipe .set_progress_bar_config (disable = None )
296308
297309 prompt = "A painting of a squirrel eating a burger"
298310 generator = torch .Generator (device = device ).manual_seed (0 )
@@ -325,6 +337,7 @@ def test_stable_diffusion_k_lms(self):
325337 feature_extractor = self .dummy_extractor ,
326338 )
327339 sd_pipe = sd_pipe .to (device )
340+ sd_pipe .set_progress_bar_config (disable = None )
328341
329342 prompt = "A painting of a squirrel eating a burger"
330343 generator = torch .Generator (device = device ).manual_seed (0 )
@@ -344,6 +357,7 @@ def test_score_sde_ve_pipeline(self):
344357
345358 sde_ve = ScoreSdeVePipeline (unet = unet , scheduler = scheduler )
346359 sde_ve .to (torch_device )
360+ sde_ve .set_progress_bar_config (disable = None )
347361
348362 torch .manual_seed (0 )
349363 image = sde_ve (num_inference_steps = 2 , output_type = "numpy" )["sample" ]
@@ -362,6 +376,7 @@ def test_ldm_uncond(self):
362376
363377 ldm = LDMPipeline (unet = unet , vqvae = vae , scheduler = scheduler )
364378 ldm .to (torch_device )
379+ ldm .set_progress_bar_config (disable = None )
365380
366381 generator = torch .manual_seed (0 )
367382 image = ldm (generator = generator , num_inference_steps = 2 , output_type = "numpy" )["sample" ]
@@ -378,6 +393,7 @@ def test_karras_ve_pipeline(self):
378393
379394 pipe = KarrasVePipeline (unet = unet , scheduler = scheduler )
380395 pipe .to (torch_device )
396+ pipe .set_progress_bar_config (disable = None )
381397
382398 generator = torch .manual_seed (0 )
383399 image = pipe (num_inference_steps = 2 , generator = generator , output_type = "numpy" )["sample" ]
@@ -408,6 +424,7 @@ def test_stable_diffusion_img2img(self):
408424 feature_extractor = self .dummy_extractor ,
409425 )
410426 sd_pipe = sd_pipe .to (device )
427+ sd_pipe .set_progress_bar_config (disable = None )
411428
412429 prompt = "A painting of a squirrel eating a burger"
413430 generator = torch .Generator (device = device ).manual_seed (0 )
@@ -451,6 +468,7 @@ def test_stable_diffusion_inpaint(self):
451468 feature_extractor = self .dummy_extractor ,
452469 )
453470 sd_pipe = sd_pipe .to (device )
471+ sd_pipe .set_progress_bar_config (disable = None )
454472
455473 prompt = "A painting of a squirrel eating a burger"
456474 generator = torch .Generator (device = device ).manual_seed (0 )
@@ -474,6 +492,12 @@ def test_stable_diffusion_inpaint(self):
474492
475493
476494class PipelineTesterMixin (unittest .TestCase ):
495+ def tearDown (self ):
496+ # clean up the VRAM after each test
497+ super ().tearDown ()
498+ gc .collect ()
499+ torch .cuda .empty_cache ()
500+
477501 def test_from_pretrained_save_pretrained (self ):
478502 # 1. Load models
479503 model = UNet2DModel (
@@ -489,6 +513,7 @@ def test_from_pretrained_save_pretrained(self):
489513
490514 ddpm = DDPMPipeline (model , schedular )
491515 ddpm .to (torch_device )
516+ ddpm .set_progress_bar_config (disable = None )
492517
493518 with tempfile .TemporaryDirectory () as tmpdirname :
494519 ddpm .save_pretrained (tmpdirname )
@@ -511,8 +536,10 @@ def test_from_pretrained_hub(self):
511536
512537 ddpm = DDPMPipeline .from_pretrained (model_path , scheduler = scheduler )
513538 ddpm .to (torch_device )
539+ ddpm .set_progress_bar_config (disable = None )
514540 ddpm_from_hub = DiffusionPipeline .from_pretrained (model_path , scheduler = scheduler )
515541 ddpm_from_hub .to (torch_device )
542+ ddpm_from_hub .set_progress_bar_config (disable = None )
516543
517544 generator = torch .manual_seed (0 )
518545
@@ -532,9 +559,11 @@ def test_from_pretrained_hub_pass_model(self):
532559 unet = UNet2DModel .from_pretrained (model_path )
533560 ddpm_from_hub_custom_model = DiffusionPipeline .from_pretrained (model_path , unet = unet , scheduler = scheduler )
534561 ddpm_from_hub_custom_model .to (torch_device )
562+ ddpm_from_hub_custom_model .set_progress_bar_config (disable = None )
535563
536564 ddpm_from_hub = DiffusionPipeline .from_pretrained (model_path , scheduler = scheduler )
537565 ddpm_from_hub .to (torch_device )
566+ ddpm_from_hub_custom_model .set_progress_bar_config (disable = None )
538567
539568 generator = torch .manual_seed (0 )
540569
@@ -550,6 +579,7 @@ def test_output_format(self):
550579
551580 pipe = DDIMPipeline .from_pretrained (model_path )
552581 pipe .to (torch_device )
582+ pipe .set_progress_bar_config (disable = None )
553583
554584 generator = torch .manual_seed (0 )
555585 images = pipe (generator = generator , output_type = "numpy" )["sample" ]
@@ -576,6 +606,7 @@ def test_ddpm_cifar10(self):
576606
577607 ddpm = DDPMPipeline (unet = unet , scheduler = scheduler )
578608 ddpm .to (torch_device )
609+ ddpm .set_progress_bar_config (disable = None )
579610
580611 generator = torch .manual_seed (0 )
581612 image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -595,6 +626,7 @@ def test_ddim_lsun(self):
595626
596627 ddpm = DDIMPipeline (unet = unet , scheduler = scheduler )
597628 ddpm .to (torch_device )
629+ ddpm .set_progress_bar_config (disable = None )
598630
599631 generator = torch .manual_seed (0 )
600632 image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -614,6 +646,7 @@ def test_ddim_cifar10(self):
614646
615647 ddim = DDIMPipeline (unet = unet , scheduler = scheduler )
616648 ddim .to (torch_device )
649+ ddim .set_progress_bar_config (disable = None )
617650
618651 generator = torch .manual_seed (0 )
619652 image = ddim (generator = generator , eta = 0.0 , output_type = "numpy" )["sample" ]
@@ -633,6 +666,7 @@ def test_pndm_cifar10(self):
633666
634667 pndm = PNDMPipeline (unet = unet , scheduler = scheduler )
635668 pndm .to (torch_device )
669+ pndm .set_progress_bar_config (disable = None )
636670 generator = torch .manual_seed (0 )
637671 image = pndm (generator = generator , output_type = "numpy" )["sample" ]
638672
@@ -646,6 +680,7 @@ def test_pndm_cifar10(self):
646680 def test_ldm_text2img (self ):
647681 ldm = LDMTextToImagePipeline .from_pretrained ("CompVis/ldm-text2im-large-256" )
648682 ldm .to (torch_device )
683+ ldm .set_progress_bar_config (disable = None )
649684
650685 prompt = "A painting of a squirrel eating a burger"
651686 generator = torch .manual_seed (0 )
@@ -663,6 +698,7 @@ def test_ldm_text2img(self):
663698 def test_ldm_text2img_fast (self ):
664699 ldm = LDMTextToImagePipeline .from_pretrained ("CompVis/ldm-text2im-large-256" )
665700 ldm .to (torch_device )
701+ ldm .set_progress_bar_config (disable = None )
666702
667703 prompt = "A painting of a squirrel eating a burger"
668704 generator = torch .manual_seed (0 )
@@ -680,6 +716,7 @@ def test_stable_diffusion(self):
680716 # make sure here that pndm scheduler skips prk
681717 sd_pipe = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1" , use_auth_token = True )
682718 sd_pipe = sd_pipe .to (torch_device )
719+ sd_pipe .set_progress_bar_config (disable = None )
683720
684721 prompt = "A painting of a squirrel eating a burger"
685722 generator = torch .Generator (device = torch_device ).manual_seed (0 )
@@ -701,6 +738,7 @@ def test_stable_diffusion(self):
701738 def test_stable_diffusion_fast_ddim (self ):
702739 sd_pipe = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1" , use_auth_token = True )
703740 sd_pipe = sd_pipe .to (torch_device )
741+ sd_pipe .set_progress_bar_config (disable = None )
704742
705743 scheduler = DDIMScheduler (
706744 beta_start = 0.00085 ,
@@ -733,6 +771,7 @@ def test_score_sde_ve_pipeline(self):
733771
734772 sde_ve = ScoreSdeVePipeline (unet = model , scheduler = scheduler )
735773 sde_ve .to (torch_device )
774+ sde_ve .set_progress_bar_config (disable = None )
736775
737776 torch .manual_seed (0 )
738777 image = sde_ve (num_inference_steps = 300 , output_type = "numpy" )["sample" ]
@@ -748,6 +787,7 @@ def test_score_sde_ve_pipeline(self):
748787 def test_ldm_uncond (self ):
749788 ldm = LDMPipeline .from_pretrained ("CompVis/ldm-celebahq-256" )
750789 ldm .to (torch_device )
790+ ldm .set_progress_bar_config (disable = None )
751791
752792 generator = torch .manual_seed (0 )
753793 image = ldm (generator = generator , num_inference_steps = 5 , output_type = "numpy" )["sample" ]
@@ -768,8 +808,10 @@ def test_ddpm_ddim_equality(self):
768808
769809 ddpm = DDPMPipeline (unet = unet , scheduler = ddpm_scheduler )
770810 ddpm .to (torch_device )
811+ ddpm .set_progress_bar_config (disable = None )
771812 ddim = DDIMPipeline (unet = unet , scheduler = ddim_scheduler )
772813 ddim .to (torch_device )
814+ ddim .set_progress_bar_config (disable = None )
773815
774816 generator = torch .manual_seed (0 )
775817 ddpm_image = ddpm (generator = generator , output_type = "numpy" )["sample" ]
@@ -790,9 +832,11 @@ def test_ddpm_ddim_equality_batched(self):
790832
791833 ddpm = DDPMPipeline (unet = unet , scheduler = ddpm_scheduler )
792834 ddpm .to (torch_device )
835+ ddpm .set_progress_bar_config (disable = None )
793836
794837 ddim = DDIMPipeline (unet = unet , scheduler = ddim_scheduler )
795838 ddim .to (torch_device )
839+ ddim .set_progress_bar_config (disable = None )
796840
797841 generator = torch .manual_seed (0 )
798842 ddpm_images = ddpm (batch_size = 4 , generator = generator , output_type = "numpy" )["sample" ]
@@ -813,6 +857,7 @@ def test_karras_ve_pipeline(self):
813857
814858 pipe = KarrasVePipeline (unet = model , scheduler = scheduler )
815859 pipe .to (torch_device )
860+ pipe .set_progress_bar_config (disable = None )
816861
817862 generator = torch .manual_seed (0 )
818863 image = pipe (num_inference_steps = 20 , generator = generator , output_type = "numpy" )["sample" ]
@@ -827,6 +872,7 @@ def test_karras_ve_pipeline(self):
827872 def test_lms_stable_diffusion_pipeline (self ):
828873 model_id = "CompVis/stable-diffusion-v1-1"
829874 pipe = StableDiffusionPipeline .from_pretrained (model_id , use_auth_token = True ).to (torch_device )
875+ pipe .set_progress_bar_config (disable = None )
830876 scheduler = LMSDiscreteScheduler .from_config (model_id , subfolder = "scheduler" , use_auth_token = True )
831877 pipe .scheduler = scheduler
832878
@@ -852,6 +898,7 @@ def test_stable_diffusion_img2img_pipeline(self):
852898 model_id = "CompVis/stable-diffusion-v1-4"
853899 pipe = StableDiffusionImg2ImgPipeline .from_pretrained (model_id , use_auth_token = True )
854900 pipe .to (torch_device )
901+ pipe .set_progress_bar_config (disable = None )
855902
856903 prompt = "A fantasy landscape, trending on artstation"
857904
@@ -878,6 +925,7 @@ def test_stable_diffusion_in_paint_pipeline(self):
878925 model_id = "CompVis/stable-diffusion-v1-4"
879926 pipe = StableDiffusionInpaintPipeline .from_pretrained (model_id , use_auth_token = True )
880927 pipe .to (torch_device )
928+ pipe .set_progress_bar_config (disable = None )
881929
882930 prompt = "A red cat sitting on a parking bench"
883931
0 commit comments