Skip to content

Commit 577a6a6

Browse files
committed
Fix SD tests .to(device)
1 parent 21ceda3 commit 577a6a6

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_modeling_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ def test_ldm_text2img_fast(self):
866866
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
867867
def test_stable_diffusion(self):
868868
# make sure here that pndm scheduler skips prk
869-
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
869+
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1").to(torch_device)
870870

871871
prompt = "A painting of a squirrel eating a burger"
872872
generator = torch.Generator(device=torch_device).manual_seed(0)
@@ -886,7 +886,7 @@ def test_stable_diffusion(self):
886886
@slow
887887
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
888888
def test_stable_diffusion_fast_ddim(self):
889-
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
889+
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1").to(torch_device)
890890

891891
scheduler = DDIMScheduler(
892892
beta_start=0.00085,
@@ -1003,8 +1003,8 @@ def test_karras_ve_pipeline(self):
10031003
@slow
10041004
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
10051005
def test_lms_stable_diffusion_pipeline(self):
1006-
model_id = "CompVis/stable-diffusion-v1-1-diffusers"
1007-
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
1006+
model_id = "CompVis/stable-diffusion-v1-1"
1007+
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(torch_device)
10081008
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
10091009
pipe.scheduler = scheduler
10101010

0 commit comments

Comments
 (0)