Skip to content

Commit 55660cf

Browse files
clarencechen, PeterL1n, patrickvonplaten
authored
Improve dynamic thresholding and extend to DDPM and DDIM Schedulers (huggingface#2528)
* Improve dynamic threshold * Update code * Add dynamic threshold to ddim and ddpm * Encapsulate and leverage code copy mechanism Update style * Clean up DDPM/DDIM constructor arguments * add test * also add to unipc --------- Co-authored-by: Peter Lin <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 46bef6e commit 55660cf

12 files changed

+171
-60
lines changed

src/diffusers/models/unet_2d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
7070
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
7171
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
7272
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
73-
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
74-
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
73+
class_embed_type (`str`, *optional*, defaults to None):
74+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
75+
`"timestep"`, or `"identity"`.
7576
num_class_embeds (`int`, *optional*, defaults to None):
7677
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
7778
class conditioning with `class_embed_type` equal to `None`.

src/diffusers/models/unet_2d_condition.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
9090
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
9191
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
9292
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
93-
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
94-
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
93+
class_embed_type (`str`, *optional*, defaults to None):
94+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
95+
`"timestep"`, `"identity"`, or `"projection"`.
9596
num_class_embeds (`int`, *optional*, defaults to None):
9697
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
9798
class conditioning with `class_embed_type` equal to `None`.

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
171171
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
172172
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
173173
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
174-
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
175-
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
174+
class_embed_type (`str`, *optional*, defaults to None):
175+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
176+
`"timestep"`, `"identity"`, or `"projection"`.
176177
num_class_embeds (`int`, *optional*, defaults to None):
177178
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
178179
class conditioning with `class_embed_type` equal to `None`.

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
9898
trained_betas (`np.ndarray`, optional):
9999
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
100100
clip_sample (`bool`, default `True`):
101-
option to clip predicted sample between -1 and 1 for numerical stability.
101+
option to clip predicted sample for numerical stability.
102+
clip_sample_range (`float`, default `1.0`):
103+
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
102104
set_alpha_to_one (`bool`, default `True`):
103105
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
104106
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
@@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
111113
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
112114
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
113115
https://imagen.research.google/video/paper.pdf)
116+
thresholding (`bool`, default `False`):
117+
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
118+
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
119+
stable-diffusion).
120+
dynamic_thresholding_ratio (`float`, default `0.995`):
121+
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
122+
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
123+
sample_max_value (`float`, default `1.0`):
124+
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
114125
"""
115126

116127
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -128,6 +139,10 @@ def __init__(
128139
set_alpha_to_one: bool = True,
129140
steps_offset: int = 0,
130141
prediction_type: str = "epsilon",
142+
thresholding: bool = False,
143+
dynamic_thresholding_ratio: float = 0.995,
144+
clip_sample_range: float = 1.0,
145+
sample_max_value: float = 1.0,
131146
):
132147
if trained_betas is not None:
133148
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -184,6 +199,18 @@ def _get_variance(self, timestep, prev_timestep):
184199

185200
return variance
186201

202+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
203+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
204+
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
205+
dynamic_max_val = (
206+
sample.flatten(1)
207+
.abs()
208+
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
209+
.clamp_min(self.config.sample_max_value)
210+
.view(-1, *([1] * (sample.ndim - 1)))
211+
)
212+
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
213+
187214
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
188215
"""
189216
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -286,9 +313,14 @@ def step(
286313
" `v_prediction`"
287314
)
288315

289-
# 4. Clip "predicted x_0"
316+
# 4. Clip or threshold "predicted x_0"
290317
if self.config.clip_sample:
291-
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
318+
pred_original_sample = pred_original_sample.clamp(
319+
-self.config.clip_sample_range, self.config.clip_sample_range
320+
)
321+
322+
if self.config.thresholding:
323+
pred_original_sample = self._threshold_sample(pred_original_sample)
292324

293325
# 5. compute variance: "sigma_t(η)" -> see formula (16)
294326
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,22 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
9898
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
9999
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
100100
clip_sample (`bool`, default `True`):
101-
option to clip predicted sample between -1 and 1 for numerical stability.
101+
option to clip predicted sample for numerical stability.
102+
clip_sample_range (`float`, default `1.0`):
103+
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
102104
prediction_type (`str`, default `epsilon`, optional):
103105
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
104106
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
105107
https://imagen.research.google/video/paper.pdf)
108+
thresholding (`bool`, default `False`):
109+
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
110+
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
111+
stable-diffusion).
112+
dynamic_thresholding_ratio (`float`, default `0.995`):
113+
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
114+
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
115+
sample_max_value (`float`, default `1.0`):
116+
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
106117
"""
107118

108119
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -119,7 +130,10 @@ def __init__(
119130
variance_type: str = "fixed_small",
120131
clip_sample: bool = True,
121132
prediction_type: str = "epsilon",
122-
clip_sample_range: Optional[float] = 1.0,
133+
thresholding: bool = False,
134+
dynamic_thresholding_ratio: float = 0.995,
135+
clip_sample_range: float = 1.0,
136+
sample_max_value: float = 1.0,
123137
):
124138
if trained_betas is not None:
125139
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -226,6 +240,17 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
226240

227241
return variance
228242

243+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
244+
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
245+
dynamic_max_val = (
246+
sample.flatten(1)
247+
.abs()
248+
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
249+
.clamp_min(self.config.sample_max_value)
250+
.view(-1, *([1] * (sample.ndim - 1)))
251+
)
252+
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
253+
229254
def step(
230255
self,
231256
model_output: torch.FloatTensor,
@@ -283,12 +308,15 @@ def step(
283308
" `v_prediction` for the DDPMScheduler."
284309
)
285310

286-
# 3. Clip "predicted x_0"
311+
# 3. Clip or threshold "predicted x_0"
287312
if self.config.clip_sample:
288-
pred_original_sample = torch.clamp(
289-
pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
313+
pred_original_sample = pred_original_sample.clamp(
314+
-self.config.clip_sample_range, self.config.clip_sample_range
290315
)
291316

317+
if self.config.thresholding:
318+
pred_original_sample = self._threshold_sample(pred_original_sample)
319+
292320
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
293321
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
294322
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
9696
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
9797
(https://arxiv.org/abs/2205.11487).
9898
sample_max_value (`float`, default `1.0`):
99-
the threshold value for dynamic thresholding. Valid woks when `thresholding=True`
99+
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
100100
algorithm_type (`str`, default `deis`):
101101
the algorithm type for the solver. Currently we support multistep deis, we will add other variants of DEIS in
102102
the future
@@ -194,6 +194,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
194194
] * self.config.solver_order
195195
self.lower_order_nums = 0
196196

197+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
198+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
199+
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
200+
dynamic_max_val = (
201+
sample.flatten(1)
202+
.abs()
203+
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
204+
.clamp_min(self.config.sample_max_value)
205+
.view(-1, *([1] * (sample.ndim - 1)))
206+
)
207+
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
208+
197209
def convert_model_output(
198210
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
199211
) -> torch.FloatTensor:
@@ -228,15 +240,7 @@ def convert_model_output(
228240
orig_dtype = x0_pred.dtype
229241
if orig_dtype not in [torch.float, torch.double]:
230242
x0_pred = x0_pred.float()
231-
dynamic_max_val = torch.quantile(
232-
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
233-
)
234-
dynamic_max_val = torch.maximum(
235-
dynamic_max_val,
236-
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
237-
)[(...,) + (None,) * (x0_pred.ndim - 1)]
238-
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
239-
x0_pred = x0_pred.type(orig_dtype)
243+
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
240244

241245
if self.config.algorithm_type == "deis":
242246
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
204204
] * self.config.solver_order
205205
self.lower_order_nums = 0
206206

207+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
208+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
209+
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
210+
dynamic_max_val = (
211+
sample.flatten(1)
212+
.abs()
213+
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
214+
.clamp_min(self.config.sample_max_value)
215+
.view(-1, *([1] * (sample.ndim - 1)))
216+
)
217+
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
218+
207219
def convert_model_output(
208220
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
209221
) -> torch.FloatTensor:
@@ -247,15 +259,7 @@ def convert_model_output(
247259
orig_dtype = x0_pred.dtype
248260
if orig_dtype not in [torch.float, torch.double]:
249261
x0_pred = x0_pred.float()
250-
dynamic_max_val = torch.quantile(
251-
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
252-
)
253-
dynamic_max_val = torch.maximum(
254-
dynamic_max_val,
255-
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
256-
)[(...,) + (None,) * (x0_pred.ndim - 1)]
257-
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
258-
x0_pred = x0_pred.type(orig_dtype)
262+
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
259263
return x0_pred
260264
# DPM-Solver needs to solve an integral of the noise prediction model.
261265
elif self.config.algorithm_type == "dpmsolver":

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
237237
self.sample = None
238238
self.orders = self.get_order_list(num_inference_steps)
239239

240+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
241+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
242+
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
243+
dynamic_max_val = (
244+
sample.flatten(1)
245+
.abs()
246+
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
247+
.clamp_min(self.config.sample_max_value)
248+
.view(-1, *([1] * (sample.ndim - 1)))
249+
)
250+
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
251+
240252
def convert_model_output(
241253
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
242254
) -> torch.FloatTensor:
@@ -277,18 +289,10 @@ def convert_model_output(
277289

278290
if self.config.thresholding:
279291
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
280-
dtype = x0_pred.dtype
281-
dynamic_max_val = torch.quantile(
282-
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)).float(),
283-
self.config.dynamic_thresholding_ratio,
284-
dim=1,
285-
)
286-
dynamic_max_val = torch.maximum(
287-
dynamic_max_val,
288-
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
289-
)[(...,) + (None,) * (x0_pred.ndim - 1)]
290-
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
291-
x0_pred = x0_pred.to(dtype)
292+
orig_dtype = x0_pred.dtype
293+
if orig_dtype not in [torch.float, torch.double]:
294+
x0_pred = x0_pred.float()
295+
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
292296
return x0_pred
293297
# DPM-Solver needs to solve an integral of the noise prediction model.
294298
elif self.config.algorithm_type == "dpmsolver":

src/diffusers/schedulers/scheduling_sde_ve.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def set_timesteps(
109109
Args:
110110
num_inference_steps (`int`):
111111
the number of diffusion steps used when generating samples with a pre-trained model.
112-
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
112+
sampling_eps (`float`, optional):
113+
final timestep value (overrides value given at Scheduler instantiation).
113114
114115
"""
115116
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
@@ -129,8 +130,10 @@ def set_sigmas(
129130
the number of diffusion steps used when generating samples with a pre-trained model.
130131
sigma_min (`float`, optional):
131132
initial noise scale value (overrides value given at Scheduler instantiation).
132-
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
133-
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
133+
sigma_max (`float`, optional):
134+
final noise scale value (overrides value given at Scheduler instantiation).
135+
sampling_eps (`float`, optional):
136+
final timestep value (overrides value given at Scheduler instantiation).
134137
135138
"""
136139
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min

0 commit comments

Comments
 (0)