Skip to content

Commit b64c522

Browse files
[PNDM Scheduler] format timesteps attrs to np arrays (huggingface#273)
* format timesteps attrs to np arrays in pndm scheduler because lists don't get formatted to tensors in `self.set_format` * convert to long type to use timesteps as indices for tensors * add scheduler set_format test * fix `_timesteps` type * make style with black 22.3.0 and isort 5.10.1 Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7eb6dfc commit b64c522

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,22 +103,24 @@ def set_timesteps(self, num_inference_steps, offset=0):
103103
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
104104
)
105105
self._offset = offset
106-
self._timesteps = [t + self._offset for t in self._timesteps]
106+
self._timesteps = np.array([t + self._offset for t in self._timesteps])
107107

108108
if self.config.skip_prk_steps:
109109
# for some models like stable diffusion the prk steps can/should be skipped to
110110
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
111111
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
112-
self.prk_timesteps = []
113-
self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:]))
112+
self.prk_timesteps = np.array([])
113+
self.plms_timesteps = (self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])[::-1].copy()
114114
else:
115115
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
116116
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
117117
)
118-
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
119-
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
118+
self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
119+
self.plms_timesteps = self._timesteps[:-3][
120+
::-1
121+
].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
120122

121-
self.timesteps = self.prk_timesteps + self.plms_timesteps
123+
self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
122124

123125
self.ets = []
124126
self.counter = 0

tests/test_scheduler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,35 @@ def test_pytorch_equal_numpy(self):
485485

486486
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
487487

488+
def test_set_format(self):
489+
kwargs = dict(self.forward_default_kwargs)
490+
num_inference_steps = kwargs.pop("num_inference_steps", None)
491+
492+
for scheduler_class in self.scheduler_classes:
493+
scheduler_config = self.get_scheduler_config()
494+
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
495+
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
496+
497+
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
498+
scheduler.set_timesteps(num_inference_steps)
499+
scheduler_pt.set_timesteps(num_inference_steps)
500+
501+
for key, value in vars(scheduler).items():
502+
# we only allow `ets` attr to be a list
503+
assert not isinstance(value, list) or key in [
504+
"ets"
505+
], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
506+
507+
# check if `scheduler.set_format` does convert correctly attrs to pt format
508+
for key, value in vars(scheduler_pt).items():
509+
# we only allow `ets` attr to be a list
510+
assert not isinstance(value, list) or key in [
511+
"ets"
512+
], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
513+
assert not isinstance(
514+
value, np.ndarray
515+
), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
516+
488517
def test_step_shape(self):
489518
kwargs = dict(self.forward_default_kwargs)
490519

0 commit comments

Comments
 (0)