@@ -528,7 +528,7 @@ def tearDown(self):
528528
529529 def test_stable_diffusion (self ):
530530 # make sure here that pndm scheduler skips prk
531- sd_pipe = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1" )
531+ sd_pipe = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1" , device_map = "auto" )
532532 sd_pipe = sd_pipe .to (torch_device )
533533 sd_pipe .set_progress_bar_config (disable = None )
534534
@@ -548,7 +548,7 @@ def test_stable_diffusion(self):
548548 assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
549549
550550 def test_stable_diffusion_fast_ddim (self ):
551- sd_pipe = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1" )
551+ sd_pipe = StableDiffusionPipeline .from_pretrained ("CompVis/stable-diffusion-v1-1" , device_map = "auto" )
552552 sd_pipe = sd_pipe .to (torch_device )
553553 sd_pipe .set_progress_bar_config (disable = None )
554554
@@ -576,7 +576,7 @@ def test_stable_diffusion_fast_ddim(self):
576576
577577 def test_lms_stable_diffusion_pipeline (self ):
578578 model_id = "CompVis/stable-diffusion-v1-1"
579- pipe = StableDiffusionPipeline .from_pretrained (model_id ).to (torch_device )
579+ pipe = StableDiffusionPipeline .from_pretrained (model_id , device_map = "auto" ).to (torch_device )
580580 pipe .set_progress_bar_config (disable = None )
581581 scheduler = LMSDiscreteScheduler .from_config (model_id , subfolder = "scheduler" )
582582 pipe .scheduler = scheduler
@@ -595,9 +595,10 @@ def test_lms_stable_diffusion_pipeline(self):
595595 def test_stable_diffusion_memory_chunking (self ):
596596 torch .cuda .reset_peak_memory_stats ()
597597 model_id = "CompVis/stable-diffusion-v1-4"
598- pipe = StableDiffusionPipeline .from_pretrained (model_id , revision = "fp16" , torch_dtype = torch . float16 ). to (
599- torch_device
598+ pipe = StableDiffusionPipeline .from_pretrained (
599+ model_id , revision = "fp16" , torch_dtype = torch . float16 , device_map = "auto"
600600 )
601+ pipe .to (torch_device )
601602 pipe .set_progress_bar_config (disable = None )
602603
603604 prompt = "a photograph of an astronaut riding a horse"
@@ -633,9 +634,10 @@ def test_stable_diffusion_memory_chunking(self):
633634 def test_stable_diffusion_text2img_pipeline_fp16 (self ):
634635 torch .cuda .reset_peak_memory_stats ()
635636 model_id = "CompVis/stable-diffusion-v1-4"
636- pipe = StableDiffusionPipeline .from_pretrained (model_id , revision = "fp16" , torch_dtype = torch . float16 ). to (
637- torch_device
637+ pipe = StableDiffusionPipeline .from_pretrained (
638+ model_id , revision = "fp16" , device_map = "auto" , torch_dtype = torch . float16
638639 )
640+ pipe = pipe .to (torch_device )
639641 pipe .set_progress_bar_config (disable = None )
640642
641643 prompt = "a photograph of an astronaut riding a horse"
@@ -670,6 +672,7 @@ def test_stable_diffusion_text2img_pipeline(self):
670672 pipe = StableDiffusionPipeline .from_pretrained (
671673 model_id ,
672674 safety_checker = None ,
675+ device_map = "auto" ,
673676 )
674677 pipe .to (torch_device )
675678 pipe .set_progress_bar_config (disable = None )
@@ -711,7 +714,7 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> No
711714 test_callback_fn .has_been_called = False
712715
713716 pipe = StableDiffusionPipeline .from_pretrained (
714- "CompVis/stable-diffusion-v1-4" , revision = "fp16" , torch_dtype = torch .float16
717+ "CompVis/stable-diffusion-v1-4" , revision = "fp16" , torch_dtype = torch .float16 , device_map = "auto"
715718 )
716719 pipe = pipe .to (torch_device )
717720 pipe .set_progress_bar_config (disable = None )
@@ -737,7 +740,7 @@ def test_stable_diffusion_accelerate_auto_device(self):
737740
738741 start_time = time .time ()
739742 pipeline_normal_load = StableDiffusionPipeline .from_pretrained (
740- pipeline_id , revision = "fp16" , torch_dtype = torch .float16 , use_auth_token = True
743+ pipeline_id , revision = "fp16" , torch_dtype = torch .float16 , device_map = "auto"
741744 )
742745 pipeline_normal_load .to (torch_device )
743746 normal_load_time = time .time () - start_time
@@ -758,7 +761,9 @@ def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
758761 pipeline_id = "CompVis/stable-diffusion-v1-4"
759762 prompt = "Andromeda galaxy in a bottle"
760763
761- pipeline = StableDiffusionPipeline .from_pretrained (pipeline_id , revision = "fp16" , torch_dtype = torch .float16 )
764+ pipeline = StableDiffusionPipeline .from_pretrained (
765+ pipeline_id , revision = "fp16" , torch_dtype = torch .float16 , device_map = "auto"
766+ )
762767 pipeline .enable_attention_slicing (1 )
763768 pipeline .enable_sequential_cpu_offload ()
764769
0 commit comments