Skip to content

Commit 80c00e5

Browse files
yiyixuxu and yiyixuxu authored
add use_karras_sigmas to KDPM2DiscreteScheduler and KDPM2AncestralDiscreteScheduler (huggingface#5111)
--------- Co-authored-by: yiyixuxu <yixu310@gmail.com>
1 parent 2badddf commit 80c00e5

File tree

3 files changed

+99
-36
lines changed

3 files changed

+99
-36
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -609,6 +609,9 @@ def model_fn(x, t):
609609
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
610610
sampler_kwargs["noise_sampler"] = noise_sampler
611611

612+
if "generator" in inspect.signature(self.sampler).parameters:
613+
sampler_kwargs["generator"] = generator
614+
612615
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
613616

614617
if not output_type == "latent":

src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

Lines changed: 41 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
8989
`linear` or `scaled_linear`.
9090
trained_betas (`np.ndarray`, *optional*):
9191
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
92+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
93+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
94+
the sigmas are determined according to a sequence of noise levels {σi}.
9295
prediction_type (`str`, defaults to `epsilon`, *optional*):
9396
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
9497
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -113,6 +116,7 @@ def __init__(
113116
beta_end: float = 0.012,
114117
beta_schedule: str = "linear",
115118
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
119+
use_karras_sigmas: Optional[bool] = False,
116120
prediction_type: str = "epsilon",
117121
timestep_spacing: str = "linspace",
118122
steps_offset: int = 0,
@@ -243,9 +247,15 @@ def set_timesteps(
243247
)
244248

245249
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
246-
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
250+
log_sigmas = np.log(sigmas)
247251

248252
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
253+
254+
if self.config.use_karras_sigmas:
255+
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
256+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
257+
258+
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
249259
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
250260
sigmas = torch.from_numpy(sigmas).to(device=device)
251261

@@ -269,7 +279,13 @@ def set_timesteps(
269279
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
270280

271281
timesteps = torch.from_numpy(timesteps).to(device)
272-
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
282+
sigmas_interpol = sigmas_interpol.cpu()
283+
log_sigmas = self.log_sigmas.cpu()
284+
timesteps_interpol = np.array(
285+
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
286+
)
287+
288+
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
273289
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
274290

275291
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -282,29 +298,44 @@ def set_timesteps(
282298

283299
self._step_index = None
284300

285-
def sigma_to_t(self, sigma):
301+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
302+
def _sigma_to_t(self, sigma, log_sigmas):
286303
# get log sigma
287-
log_sigma = sigma.log()
304+
log_sigma = np.log(sigma)
288305

289306
# get distribution
290-
dists = log_sigma - self.log_sigmas[:, None]
307+
dists = log_sigma - log_sigmas[:, np.newaxis]
291308

292309
# get sigmas range
293-
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
310+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
294311
high_idx = low_idx + 1
295312

296-
low = self.log_sigmas[low_idx]
297-
high = self.log_sigmas[high_idx]
313+
low = log_sigmas[low_idx]
314+
high = log_sigmas[high_idx]
298315

299316
# interpolate sigmas
300317
w = (low - log_sigma) / (low - high)
301-
w = w.clamp(0, 1)
318+
w = np.clip(w, 0, 1)
302319

303320
# transform interpolation to time range
304321
t = (1 - w) * low_idx + w * high_idx
305-
t = t.view(sigma.shape)
322+
t = t.reshape(sigma.shape)
306323
return t
307324

325+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
326+
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
327+
"""Constructs the noise schedule of Karras et al. (2022)."""
328+
329+
sigma_min: float = in_sigmas[-1].item()
330+
sigma_max: float = in_sigmas[0].item()
331+
332+
rho = 7.0 # 7.0 is the value used in the paper
333+
ramp = np.linspace(0, 1, num_inference_steps)
334+
min_inv_rho = sigma_min ** (1 / rho)
335+
max_inv_rho = sigma_max ** (1 / rho)
336+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
337+
return sigmas
338+
308339
@property
309340
def state_in_first_order(self):
310341
return self.sample is None

src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py

Lines changed: 55 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
8888
`linear` or `scaled_linear`.
8989
trained_betas (`np.ndarray`, *optional*):
9090
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
91+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
92+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
93+
the sigmas are determined according to a sequence of noise levels {σi}.
9194
prediction_type (`str`, defaults to `epsilon`, *optional*):
9295
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
9396
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -112,6 +115,7 @@ def __init__(
112115
beta_end: float = 0.012,
113116
beta_schedule: str = "linear",
114117
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
118+
use_karras_sigmas: Optional[bool] = False,
115119
prediction_type: str = "epsilon",
116120
timestep_spacing: str = "linspace",
117121
steps_offset: int = 0,
@@ -243,9 +247,14 @@ def set_timesteps(
243247
)
244248

245249
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
246-
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
247-
250+
log_sigmas = np.log(sigmas)
248251
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
252+
253+
if self.config.use_karras_sigmas:
254+
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
255+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
256+
257+
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
249258
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
250259
sigmas = torch.from_numpy(sigmas).to(device=device)
251260

@@ -260,7 +269,12 @@ def set_timesteps(
260269
timesteps = torch.from_numpy(timesteps).to(device)
261270

262271
# interpolate timesteps
263-
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
272+
sigmas_interpol = sigmas_interpol.cpu()
273+
log_sigmas = self.log_sigmas.cpu()
274+
timesteps_interpol = np.array(
275+
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
276+
)
277+
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
264278
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
265279

266280
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -273,29 +287,6 @@ def set_timesteps(
273287

274288
self._step_index = None
275289

276-
def sigma_to_t(self, sigma):
277-
# get log sigma
278-
log_sigma = sigma.log()
279-
280-
# get distribution
281-
dists = log_sigma - self.log_sigmas[:, None]
282-
283-
# get sigmas range
284-
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
285-
high_idx = low_idx + 1
286-
287-
low = self.log_sigmas[low_idx]
288-
high = self.log_sigmas[high_idx]
289-
290-
# interpolate sigmas
291-
w = (low - log_sigma) / (low - high)
292-
w = w.clamp(0, 1)
293-
294-
# transform interpolation to time range
295-
t = (1 - w) * low_idx + w * high_idx
296-
t = t.view(sigma.shape)
297-
return t
298-
299290
@property
300291
def state_in_first_order(self):
301292
return self.sample is None
@@ -318,6 +309,44 @@ def _init_step_index(self, timestep):
318309

319310
self._step_index = step_index.item()
320311

312+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
313+
def _sigma_to_t(self, sigma, log_sigmas):
314+
# get log sigma
315+
log_sigma = np.log(sigma)
316+
317+
# get distribution
318+
dists = log_sigma - log_sigmas[:, np.newaxis]
319+
320+
# get sigmas range
321+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
322+
high_idx = low_idx + 1
323+
324+
low = log_sigmas[low_idx]
325+
high = log_sigmas[high_idx]
326+
327+
# interpolate sigmas
328+
w = (low - log_sigma) / (low - high)
329+
w = np.clip(w, 0, 1)
330+
331+
# transform interpolation to time range
332+
t = (1 - w) * low_idx + w * high_idx
333+
t = t.reshape(sigma.shape)
334+
return t
335+
336+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
337+
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
338+
"""Constructs the noise schedule of Karras et al. (2022)."""
339+
340+
sigma_min: float = in_sigmas[-1].item()
341+
sigma_max: float = in_sigmas[0].item()
342+
343+
rho = 7.0 # 7.0 is the value used in the paper
344+
ramp = np.linspace(0, 1, num_inference_steps)
345+
min_inv_rho = sigma_min ** (1 / rho)
346+
max_inv_rho = sigma_max ** (1 / rho)
347+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
348+
return sigmas
349+
321350
def step(
322351
self,
323352
model_output: Union[torch.FloatTensor, np.ndarray],

0 commit comments

Comments
 (0)