@@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
9898 trained_betas (`np.ndarray`, optional):
9999 option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
100100 clip_sample (`bool`, default `True`):
101- option to clip predicted sample between -1 and 1 for numerical stability.
101+ option to clip predicted sample for numerical stability.
102+ clip_sample_range (`float`, default `1.0`):
103+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
102104 set_alpha_to_one (`bool`, default `True`):
103105 each diffusion step uses the value of alphas product at that step and at the previous one. For the final
104106 step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
@@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
111113 prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
112114 process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
113115 https://imagen.research.google/video/paper.pdf)
116+ thresholding (`bool`, default `False`):
117+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
118+ Note that the thresholding method is unsuitable for latent-space diffusion models (such as
119+ stable-diffusion).
120+ dynamic_thresholding_ratio (`float`, default `0.995`):
121+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
122+ (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
123+ sample_max_value (`float`, default `1.0`):
124+ the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
114125 """
115126
116127 _compatibles = [e .name for e in KarrasDiffusionSchedulers ]
@@ -128,6 +139,10 @@ def __init__(
128139 set_alpha_to_one : bool = True ,
129140 steps_offset : int = 0 ,
130141 prediction_type : str = "epsilon" ,
142+ thresholding : bool = False ,
143+ dynamic_thresholding_ratio : float = 0.995 ,
144+ clip_sample_range : float = 1.0 ,
145+ sample_max_value : float = 1.0 ,
131146 ):
132147 if trained_betas is not None :
133148 self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
@@ -184,6 +199,18 @@ def _get_variance(self, timestep, prev_timestep):
184199
185200 return variance
186201
202+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
203+ def _threshold_sample (self , sample : torch .FloatTensor ) -> torch .FloatTensor :
204+ # Dynamic thresholding in https://arxiv.org/abs/2205.11487
205+ dynamic_max_val = (
206+ sample .flatten (1 )
207+ .abs ()
208+ .quantile (self .config .dynamic_thresholding_ratio , dim = 1 )
209+ .clamp_min (self .config .sample_max_value )
210+ .view (- 1 , * ([1 ] * (sample .ndim - 1 )))
211+ )
212+ return sample .clamp (- dynamic_max_val , dynamic_max_val ) / dynamic_max_val
213+
187214 def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
188215 """
189216 Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -286,9 +313,14 @@ def step(
286313 " `v_prediction`"
287314 )
288315
289- # 4. Clip "predicted x_0"
316+ # 4. Clip or threshold "predicted x_0"
290317 if self .config .clip_sample :
291- pred_original_sample = torch .clamp (pred_original_sample , - 1 , 1 )
318+ pred_original_sample = pred_original_sample .clamp (
319+ - self .config .clip_sample_range , self .config .clip_sample_range
320+ )
321+
322+ if self .config .thresholding :
323+ pred_original_sample = self ._threshold_sample (pred_original_sample )
292324
293325 # 5. compute variance: "sigma_t(η)" -> see formula (16)
294326 # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
0 commit comments