@@ -162,6 +162,7 @@ def __init__(
         self.init_noise_sigma = 1.0
 
         # setable values
+        self.custom_timesteps = False
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
@@ -191,31 +192,62 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] =
         """
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
         Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
+            num_inference_steps (`Optional[int]`):
+                the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+                `timesteps` must be `None`.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved.
+            timesteps (`List[int]`, optional):
+                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+                must be `None`.
+
         """
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+
+        if timesteps is not None:
+            for i in range(1, len(timesteps)):
+                if timesteps[i] >= timesteps[i - 1]:
+                    raise ValueError("`timesteps` must be in descending order.")
+
+            if timesteps[0] >= self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`timesteps` must start before `self.config.num_train_timesteps`:"
+                    f" {self.config.num_train_timesteps}."
+                )
+
+            timesteps = np.array(timesteps, dtype=np.int64)
+            self.custom_timesteps = True
+        else:
+            if num_inference_steps > self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                    f" maximal {self.config.num_train_timesteps} timesteps."
+                )
 
-        if num_inference_steps > self.config.num_train_timesteps:
-            raise ValueError(
-                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
-                f" maximal {self.config.num_train_timesteps} timesteps."
-            )
+            self.num_inference_steps = num_inference_steps
 
-        self.num_inference_steps = num_inference_steps
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            self.custom_timesteps = False
 
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = t - self.config.num_train_timesteps // num_inference_steps
+        prev_t = self.previous_timestep(t)
+
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
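
For reference, a minimal sketch of the two code paths the `set_timesteps` hunk above introduces, assuming the `DDPMScheduler` exported by `diffusers`; the `num_train_timesteps` value and the specific custom timesteps below are illustrative only, not part of the diff:

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)  # illustrative config

# Default path: equally spaced timesteps derived from num_inference_steps.
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.custom_timesteps)  # False
print(scheduler.timesteps)         # tensor([900, 800, ..., 100, 0]), stride 1000 // 10

# New path: arbitrary, strictly descending timesteps below num_train_timesteps.
scheduler.set_timesteps(timesteps=[999, 750, 500, 250, 100, 50, 25, 10, 0])
print(scheduler.custom_timesteps)  # True
print(scheduler.timesteps)         # tensor([999, 750, 500, 250, 100, 50, 25, 10, 0])

# Passing both arguments, or an ascending list, raises a ValueError per the checks above.
```
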
@@ -314,8 +346,8 @@ def step(
 
         """
         t = timestep
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        prev_t = self.previous_timestep(t)
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
@@ -428,3 +460,18 @@ def get_velocity(
 
     def __len__(self):
         return self.config.num_train_timesteps
+
+    def previous_timestep(self, timestep):
+        if self.custom_timesteps:
+            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+            if index == self.timesteps.shape[0] - 1:
+                prev_t = torch.tensor(-1)
+            else:
+                prev_t = self.timesteps[index + 1]
+        else:
+            num_inference_steps = (
+                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+            )
+            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        return prev_t
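
A short sketch of how the new `previous_timestep` helper resolves the preceding timestep in both modes (values again illustrative, assuming the same `DDPMScheduler` as above):

```python
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# Default spacing: the previous timestep sits a fixed stride below the current one.
scheduler.set_timesteps(num_inference_steps=10)
scheduler.previous_timestep(900)  # 800, i.e. 900 - 1000 // 10

# Custom spacing: the previous timestep is simply the next entry in self.timesteps,
# and -1 after the last entry, so alpha_prod_t_prev falls back to self.one.
scheduler.set_timesteps(timesteps=[999, 750, 500, 250, 0])
scheduler.previous_timestep(750)  # tensor(500)
scheduler.previous_timestep(0)    # tensor(-1)
```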