
Commit 813744e

MPS schedulers: don't use float64 (huggingface#1169)
* Schedulers: don't use float64 on mps
* Test set_timesteps() on device (float schedulers).
* SD pipeline: use device in set_timesteps.
* SD in-painting pipeline: use device in set_timesteps.
* Tests: fix mps crashes.
* Skip test_load_pipeline_from_git on mps. Not compatible with float16.
* Use device.type instead of str in Euler schedulers.
1 parent 5a8b356 commit 813744e

8 files changed: +117 -22 lines


src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 3 additions & 6 deletions
@@ -360,12 +360,9 @@ def __call__(
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
         latents = latents.to(self.device)
 
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps_tensor = self.scheduler.timesteps.to(self.device)
+        # set timesteps and move to the correct device
+        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        timesteps_tensor = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
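
The new device argument lets callers create the timesteps on the execution device in a single call instead of building them on CPU and transferring afterwards. A minimal sketch of the calling pattern (scheduler choice, step count, and config values are illustrative, not taken from this commit):

    import torch
    from diffusers import LMSDiscreteScheduler

    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

    # pick mps when available, falling back to cpu (assumes torch >= 1.12 for the mps backend)
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

    # timesteps and sigmas land on `device` directly; on mps the timesteps are cast to float32
    scheduler.set_timesteps(50, device=device)
    print(scheduler.timesteps.device, scheduler.timesteps.dtype)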

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 3 additions & 6 deletions
@@ -416,12 +416,9 @@ def __call__(
                 " `pipeline.unet` or your `mask_image` or `image` input."
             )
 
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps_tensor = self.scheduler.timesteps.to(self.device)
+        # set timesteps and move to the correct device
+        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        timesteps_tensor = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 7 additions & 3 deletions
@@ -151,7 +151,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+        if str(device).startswith("mps"):
+            # mps does not support float64
+            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+        else:
+            self.timesteps = torch.from_numpy(timesteps).to(device=device)
 
     def step(
         self,

@@ -217,8 +221,8 @@ def step(
 
         prev_sample = sample + derivative * dt
 
-        device = model_output.device if torch.is_tensor(model_output) else "cpu"
-        if str(device) == "mps":
+        device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
+        if device.type == "mps":
             # randn does not work reproducibly on mps
             noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
                 device
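
Replacing str(device) == "mps" with device.type == "mps" avoids a subtle miss: stringifying a device that carries an index includes the index, so an indexed mps device would have skipped the workaround. A quick illustration in plain PyTorch (assumes a build that recognizes the mps device type, torch >= 1.12):

    import torch

    device = torch.device("mps", 0)
    print(str(device))   # "mps:0" -- the old string comparison would not match this
    print(device.type)   # "mps"   -- matches regardless of the device index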

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 7 additions & 3 deletions
@@ -152,7 +152,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+        if str(device).startswith("mps"):
+            # mps does not support float64
+            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+        else:
+            self.timesteps = torch.from_numpy(timesteps).to(device=device)
 
     def step(
         self,

@@ -214,8 +218,8 @@ def step(
 
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
 
-        device = model_output.device if torch.is_tensor(model_output) else "cpu"
-        if str(device) == "mps":
+        device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
+        if device.type == "mps":
             # randn does not work reproducibly on mps
             noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
                 device
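
Both Euler schedulers keep the existing workaround for seeded noise on mps: draw from a CPU generator, then move the tensor over, because torch.randn is not reproducible on mps. A condensed sketch of that pattern (noise_like is a hypothetical helper, not part of diffusers; the non-mps branch assumes the generator lives on the sample's device):

    import torch

    def noise_like(sample: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
        # randn does not work reproducibly on mps, so sample on CPU and transfer
        if sample.device.type == "mps":
            return torch.randn(sample.shape, dtype=sample.dtype, device="cpu", generator=generator).to(sample.device)
        return torch.randn(sample.shape, dtype=sample.dtype, device=sample.device, generator=generator)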

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 6 additions & 1 deletion
@@ -173,8 +173,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+        if str(device).startswith("mps"):
+            # mps does not support float64
+            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+        else:
+            self.timesteps = torch.from_numpy(timesteps).to(device=device)
 
         self.derivatives = []
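
The float32 cast is needed because these float schedulers build their timesteps with np.linspace, which returns float64 by default, and torch.from_numpy preserves the numpy dtype; mps tensors cannot hold float64. A small demonstration that runs on CPU (the values are arbitrary):

    import numpy as np
    import torch

    timesteps = np.linspace(0, 999, 50)[::-1].copy()  # float64 by default; copy() avoids negative strides
    t = torch.from_numpy(timesteps)
    print(t.dtype)                           # torch.float64 -- moving this tensor to mps would fail
    print(t.to(dtype=torch.float32).dtype)   # torch.float32 -- safe on mps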

tests/models/test_models_unet_2d.py

Lines changed: 3 additions & 0 deletions
@@ -456,6 +456,7 @@ def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
         latents = self.get_latents(seed)

@@ -507,6 +508,7 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
         latents = self.get_latents(seed)

@@ -558,6 +560,7 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
         latents = self.get_latents(seed, shape=(4, 9, 64, 64))

tests/test_pipelines.py

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
+from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
 from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

@@ -124,7 +124,7 @@ def test_local_custom_pipeline(self):
         assert output_str == "This is a local test"
 
     @slow
-    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    @require_torch_gpu
     def test_load_pipeline_from_git(self):
         clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
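
require_torch_gpu is stricter than the guard it replaces: the old skipIf only excluded cpu, so the test still ran on mps, where this pipeline's float16 weights are not supported. Conceptually the decorator behaves like the hypothetical sketch below (not the actual diffusers implementation):

    import unittest
    import torch

    def require_torch_gpu(test_case):
        # skip on cpu and mps alike; only a CUDA device runs the test
        return unittest.skipUnless(torch.cuda.is_available(), "test requires CUDA")(test_case)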

tests/test_scheduler.py

Lines changed: 86 additions & 1 deletion
@@ -83,8 +83,8 @@ def check_over_configs(self, time_step=0, **config):
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
 
-        # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
         for scheduler_class in self.scheduler_classes:
+            # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                 time_step = float(time_step)

@@ -1192,6 +1192,31 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 1006.388) < 1e-2
         assert abs(result_mean.item() - 1.31) < 1e-3
 
+    def test_full_loop_device(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 1006.388) < 1e-2
+        assert abs(result_mean.item() - 1.31) < 1e-3
+
 
 class EulerDiscreteSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (EulerDiscreteScheduler,)

@@ -1248,6 +1273,34 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 10.0807) < 1e-2
         assert abs(result_mean.item() - 0.0131) < 1e-3
 
+    def test_full_loop_device(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        generator = torch.Generator().manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+        print(result_sum, result_mean)
+
+        assert abs(result_sum.item() - 10.0807) < 1e-2
+        assert abs(result_mean.item() - 0.0131) < 1e-3
+
 
 class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (EulerAncestralDiscreteScheduler,)

@@ -1303,6 +1356,38 @@ def test_full_loop_no_noise(self):
         assert abs(result_sum.item() - 152.3192) < 1e-2
         assert abs(result_mean.item() - 0.1983) < 1e-3
 
+    def test_full_loop_device(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        generator = torch.Generator().manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+        print(result_sum, result_mean)
+
+        if not str(torch_device).startswith("mps"):
+            # The following sum varies between 148 and 156 on mps. Why?
+            assert abs(result_sum.item() - 152.3192) < 1e-2
+            assert abs(result_mean.item() - 0.1983) < 1e-3
+        else:
+            # Larger tolerance on mps
+            assert abs(result_mean.item() - 0.1983) < 1e-2
+
 
 class IPNDMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (IPNDMScheduler,)
