Skip to content

Commit a3efa43

Browse files
Fix DDIM on Windows not using int64 for timesteps (huggingface#819)
1 parent 728a3f3 commit a3efa43

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __init__(
149149

150150
# setable values
151151
self.num_inference_steps = None
152-
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
152+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
153153

154154
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
155155
"""
@@ -192,7 +192,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
192192
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
193193
# creates integer timesteps by multiplying by ratio
194194
# casting to int to avoid issues when num_inference_step is power of 3
195-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
195+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
196196
self.timesteps = torch.from_numpy(timesteps).to(device)
197197
self.timesteps += offset
198198

0 commit comments

Comments
 (0)