@@ -88,6 +88,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
8888 `linear` or `scaled_linear`.
8989 trained_betas (`np.ndarray`, *optional*):
9090 Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
91+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
92+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
93+ the sigmas are determined according to a sequence of noise levels {σi}.
9194 prediction_type (`str`, defaults to `epsilon`, *optional*):
9295 Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
9396 `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -112,6 +115,7 @@ def __init__(
112115 beta_end : float = 0.012 ,
113116 beta_schedule : str = "linear" ,
114117 trained_betas : Optional [Union [np .ndarray , List [float ]]] = None ,
118+ use_karras_sigmas : Optional [bool ] = False ,
115119 prediction_type : str = "epsilon" ,
116120 timestep_spacing : str = "linspace" ,
117121 steps_offset : int = 0 ,
@@ -243,9 +247,14 @@ def set_timesteps(
243247 )
244248
245249 sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
246- self .log_sigmas = torch .from_numpy (np .log (sigmas )).to (device )
247-
250+ log_sigmas = np .log (sigmas )
248251 sigmas = np .interp (timesteps , np .arange (0 , len (sigmas )), sigmas )
252+
253+ if self .config .use_karras_sigmas :
254+ sigmas = self ._convert_to_karras (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
255+ timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ]).round ()
256+
257+ self .log_sigmas = torch .from_numpy (log_sigmas ).to (device = device )
249258 sigmas = np .concatenate ([sigmas , [0.0 ]]).astype (np .float32 )
250259 sigmas = torch .from_numpy (sigmas ).to (device = device )
251260
@@ -260,7 +269,12 @@ def set_timesteps(
260269 timesteps = torch .from_numpy (timesteps ).to (device )
261270
262271 # interpolate timesteps
263- timesteps_interpol = self .sigma_to_t (sigmas_interpol ).to (device , dtype = timesteps .dtype )
272+ sigmas_interpol = sigmas_interpol .cpu ()
273+ log_sigmas = self .log_sigmas .cpu ()
274+ timesteps_interpol = np .array (
275+ [self ._sigma_to_t (sigma_interpol , log_sigmas ) for sigma_interpol in sigmas_interpol ]
276+ )
277+ timesteps_interpol = torch .from_numpy (timesteps_interpol ).to (device , dtype = timesteps .dtype )
264278 interleaved_timesteps = torch .stack ((timesteps_interpol [1 :- 1 , None ], timesteps [1 :, None ]), dim = - 1 ).flatten ()
265279
266280 self .timesteps = torch .cat ([timesteps [:1 ], interleaved_timesteps ])
@@ -273,29 +287,6 @@ def set_timesteps(
273287
274288 self ._step_index = None
275289
276- def sigma_to_t (self , sigma ):
277- # get log sigma
278- log_sigma = sigma .log ()
279-
280- # get distribution
281- dists = log_sigma - self .log_sigmas [:, None ]
282-
283- # get sigmas range
284- low_idx = dists .ge (0 ).cumsum (dim = 0 ).argmax (dim = 0 ).clamp (max = self .log_sigmas .shape [0 ] - 2 )
285- high_idx = low_idx + 1
286-
287- low = self .log_sigmas [low_idx ]
288- high = self .log_sigmas [high_idx ]
289-
290- # interpolate sigmas
291- w = (low - log_sigma ) / (low - high )
292- w = w .clamp (0 , 1 )
293-
294- # transform interpolation to time range
295- t = (1 - w ) * low_idx + w * high_idx
296- t = t .view (sigma .shape )
297- return t
298-
@property
def state_in_first_order(self):
    """Whether the scheduler is in its first-order step (no stashed sample).

    KDPM2 alternates two sub-steps; during the second-order sub-step the
    intermediate result is held in ``self.sample``, so a ``None`` there
    means the next call performs the first-order update.
    """
    first_order = self.sample is None
    return first_order
@@ -318,6 +309,44 @@ def _init_step_index(self, timestep):
318309
319310 self ._step_index = step_index .item ()
320311
312+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
313+ def _sigma_to_t (self , sigma , log_sigmas ):
314+ # get log sigma
315+ log_sigma = np .log (sigma )
316+
317+ # get distribution
318+ dists = log_sigma - log_sigmas [:, np .newaxis ]
319+
320+ # get sigmas range
321+ low_idx = np .cumsum ((dists >= 0 ), axis = 0 ).argmax (axis = 0 ).clip (max = log_sigmas .shape [0 ] - 2 )
322+ high_idx = low_idx + 1
323+
324+ low = log_sigmas [low_idx ]
325+ high = log_sigmas [high_idx ]
326+
327+ # interpolate sigmas
328+ w = (low - log_sigma ) / (low - high )
329+ w = np .clip (w , 0 , 1 )
330+
331+ # transform interpolation to time range
332+ t = (1 - w ) * low_idx + w * high_idx
333+ t = t .reshape (sigma .shape )
334+ return t
335+
336+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
337+ def _convert_to_karras (self , in_sigmas : torch .FloatTensor , num_inference_steps ) -> torch .FloatTensor :
338+ """Constructs the noise schedule of Karras et al. (2022)."""
339+
340+ sigma_min : float = in_sigmas [- 1 ].item ()
341+ sigma_max : float = in_sigmas [0 ].item ()
342+
343+ rho = 7.0 # 7.0 is the value used in the paper
344+ ramp = np .linspace (0 , 1 , num_inference_steps )
345+ min_inv_rho = sigma_min ** (1 / rho )
346+ max_inv_rho = sigma_max ** (1 / rho )
347+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho )) ** rho
348+ return sigmas
349+
321350 def step (
322351 self ,
323352 model_output : Union [torch .FloatTensor , np .ndarray ],
0 commit comments