@@ -158,6 +158,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
158158 use_karras_sigmas (`bool`, *optional*, defaults to `False`):
159159 Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
160160 the sigmas are determined according to a sequence of noise levels {σi}.
161+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
162+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
161163 timestep_spacing (`str`, defaults to `"linspace"`):
162164 The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
163165 Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -186,6 +188,7 @@ def __init__(
186188 prediction_type : str = "epsilon" ,
187189 interpolation_type : str = "linear" ,
188190 use_karras_sigmas : Optional [bool ] = False ,
191+ use_exponential_sigmas : Optional [bool ] = False ,
189192 sigma_min : Optional [float ] = None ,
190193 sigma_max : Optional [float ] = None ,
191194 timestep_spacing : str = "linspace" ,
@@ -235,6 +238,7 @@ def __init__(
235238
236239 self .is_scale_input_called = False
237240 self .use_karras_sigmas = use_karras_sigmas
241+ self .use_exponential_sigmas = use_exponential_sigmas
238242
239243 self ._step_index = None
240244 self ._begin_index = None
@@ -332,6 +336,12 @@ def set_timesteps(
332336 raise ValueError ("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`." )
333337 if timesteps is not None and self .config .use_karras_sigmas :
334338 raise ValueError ("Cannot set `timesteps` with `config.use_karras_sigmas = True`." )
339+ if timesteps is not None and self .config .use_exponential_sigmas :
340+ raise ValueError ("Cannot set `timesteps` with `config.use_exponential_sigmas = True`." )
341+ if self .config .use_exponential_sigmas and self .config .use_karras_sigmas :
342+ raise ValueError (
343+ "Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`"
344+ )
335345 if (
336346 timesteps is not None
337347 and self .config .timestep_type == "continuous"
@@ -396,6 +406,10 @@ def set_timesteps(
396406 sigmas = self ._convert_to_karras (in_sigmas = sigmas , num_inference_steps = self .num_inference_steps )
397407 timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ])
398408
409+ elif self .config .use_exponential_sigmas :
410+ sigmas = self ._convert_to_exponential (in_sigmas = sigmas , num_inference_steps = self .num_inference_steps )
411+ timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ])
412+
399413 if self .config .final_sigmas_type == "sigma_min" :
400414 sigma_last = ((1 - self .alphas_cumprod [0 ]) / self .alphas_cumprod [0 ]) ** 0.5
401415 elif self .config .final_sigmas_type == "zero" :
@@ -468,6 +482,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
468482 sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho )) ** rho
469483 return sigmas
470484
485+ # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
486+ def _convert_to_exponential (self , in_sigmas : torch .Tensor , num_inference_steps : int ) -> torch .Tensor :
487+ """Constructs an exponential noise schedule."""
488+
489+ # Hack to make sure that other schedulers which copy this function don't break
490+ # TODO: Add this logic to the other schedulers
491+ if hasattr (self .config , "sigma_min" ):
492+ sigma_min = self .config .sigma_min
493+ else :
494+ sigma_min = None
495+
496+ if hasattr (self .config , "sigma_max" ):
497+ sigma_max = self .config .sigma_max
498+ else :
499+ sigma_max = None
500+
501+ sigma_min = sigma_min if sigma_min is not None else in_sigmas [- 1 ].item ()
502+ sigma_max = sigma_max if sigma_max is not None else in_sigmas [0 ].item ()
503+
504+ sigmas = torch .linspace (math .log (sigma_max ), math .log (sigma_min ), num_inference_steps ).exp ()
505+ return sigmas
506+
471507 def index_for_timestep (self , timestep , schedule_timesteps = None ):
472508 if schedule_timesteps is None :
473509 schedule_timesteps = self .timesteps
0 commit comments