
Commit b85bb07

support v prediction in other schedulers (huggingface#1505)
* support v prediction in other schedulers
* v heun
* add tests for v pred
* fix tests
* fix test euler a
* v ddpm
1 parent 52eb034 commit b85bb07

6 files changed: +247 −6 lines changed
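The schedulers touched below now read a `prediction_type` field from their config, so a v-prediction model can be paired with them by passing the new value at construction time. A minimal sketch of that, not part of the commit; the beta settings are illustrative assumptions:

    # Sketch: constructing one of the updated schedulers with the new
    # `prediction_type` config value. The beta settings are assumptions.
    from diffusers import EulerAncestralDiscreteScheduler

    scheduler = EulerAncestralDiscreteScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        prediction_type="v_prediction",  # field added by this commit
    )
    scheduler.set_timesteps(25)

    # step() will now interpret the model output as v rather than epsilon.
    print(scheduler.config.prediction_type)  # -> "v_prediction"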

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 4 additions & 2 deletions
@@ -280,10 +280,12 @@ def step(
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
+        elif self.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
         else:
             raise ValueError(
-                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
-                " for the DDPMScheduler."
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+                " `v_prediction` for the DDPMScheduler."
             )
 
         # 3. Clip "predicted x_0"
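The added branch follows the v parameterization from progressive distillation, where v = sqrt(alpha_prod_t) * eps - sqrt(beta_prod_t) * x0; substituting the ideal v target into the new line recovers x0 exactly. A small self-contained check of that algebra (illustrative, not from the commit):

    # Sketch: verify that the new DDPM `v_prediction` branch recovers x0 exactly
    # when the model output is the ideal v target. All tensors are synthetic.
    import torch

    alpha_prod_t = torch.tensor(0.7)
    beta_prod_t = 1 - alpha_prod_t

    x0 = torch.randn(4)
    noise = torch.randn(4)
    sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * noise  # noised sample x_t
    v = alpha_prod_t**0.5 * noise - beta_prod_t**0.5 * x0       # ideal v target

    pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * v
    assert torch.allclose(pred_original_sample, x0, atol=1e-6)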

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 11 additions & 1 deletion
@@ -78,6 +78,7 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -202,7 +203,16 @@ def step(
         sigma = self.sigmas[step_index]
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )
+
         sigma_from = self.sigmas[step_index]
         sigma_to = self.sigmas[step_index + 1]
         sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
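In the sigma-space schedulers the same conversion is expressed through the k-diffusion style preconditioning, with c_out = -sigma / (sigma**2 + 1) ** 0.5 and c_skip = 1 / (sigma**2 + 1), as the `# * c_out + input * c_skip` comment hints. A quick algebra check of that form (illustrative, not from the commit):

    # Sketch: with x_t = x0 + sigma * noise and the ideal v target expressed in
    # sigma space, v = (noise - sigma * x0) / (sigma**2 + 1) ** 0.5, the new
    # expression recovers x0 exactly.
    import torch

    sigma = torch.tensor(2.5)
    x0 = torch.randn(4)
    noise = torch.randn(4)

    sample = x0 + sigma * noise
    v = (noise - sigma * x0) / (sigma**2 + 1) ** 0.5

    pred_original_sample = v * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    assert torch.allclose(pred_original_sample, x0, atol=1e-5)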

src/diffusers/schedulers/scheduling_heun.py

Lines changed: 10 additions & 1 deletion
@@ -54,6 +54,7 @@ def __init__(
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -184,7 +185,15 @@ def step(
         sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma_hat * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma_hat * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )
 
         if self.state_in_first_order:
             # 2. Convert to an ODE derivative
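Because the new kwarg goes through the scheduler's config registration, it round-trips with the rest of the config. A sketch of that behaviour; the class name `HeunDiscreteScheduler` is assumed to be what scheduling_heun.py defines at the time of this commit:

    # Sketch: the `prediction_type` kwarg is stored on the scheduler config and
    # survives a config round trip like any other field.
    from diffusers import HeunDiscreteScheduler  # name assumed from scheduling_heun.py

    scheduler = HeunDiscreteScheduler(prediction_type="v_prediction")
    reloaded = HeunDiscreteScheduler.from_config(scheduler.config)
    assert reloaded.config.prediction_type == "v_prediction"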

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 10 additions & 1 deletion
@@ -78,6 +78,7 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -215,7 +216,15 @@ def step(
         sigma = self.sigmas[step_index]
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma * model_output
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = sample - sigma * model_output
+        elif self.config.prediction_type == "v_prediction":
+            # * c_out + input * c_skip
+            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+            )
 
         # 2. Convert to an ODE derivative
         derivative = (sample - pred_original_sample) / sigma
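In practice these schedulers are usually not constructed by hand; the common pattern is to copy the config that ships with a v-prediction checkpoint. A hedged sketch of that flow; the model id and its use of v-prediction are assumptions about a particular checkpoint, not stated in this diff:

    # Sketch: swapping the updated LMS scheduler into a pipeline whose checkpoint
    # was trained with v-prediction. `from_config` copies `prediction_type` along
    # with the other fields from the shipped scheduler config.
    import torch
    from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

    model_id = "stabilityai/stable-diffusion-2"  # assumed v-prediction checkpoint
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    image = pipe("a photo of an astronaut riding a horse").images[0]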

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 8 additions & 0 deletions
@@ -102,6 +102,7 @@ def __init__(
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
+        prediction_type: str = "epsilon",
         steps_offset: int = 0,
     ):
         if trained_betas is not None:

@@ -368,6 +369,13 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
+        if self.config.prediction_type == "v_prediction":
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        elif self.config.prediction_type != "epsilon":
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+            )
+
         # corresponds to (α_(t−δ) - α_t) divided by
         # denominator of x_t in formula (9) and plus 1
         # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
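Unlike the other schedulers, PNDM keeps its update rule in terms of epsilon, so here the commit converts the v output back to an epsilon prediction via eps = sqrt(alpha_prod_t) * v + sqrt(beta_prod_t) * x_t, which is exact for the ideal v target. A small check of that identity (illustrative, not from the commit):

    # Sketch: converting an ideal v prediction back to epsilon, as the new PNDM
    # branch does before `_get_prev_sample` proceeds with its epsilon-based update.
    import torch

    alpha_prod_t = torch.tensor(0.3)
    beta_prod_t = 1 - alpha_prod_t

    x0 = torch.randn(4)
    eps = torch.randn(4)
    sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps  # noised sample x_t
    v = alpha_prod_t**0.5 * eps - beta_prod_t**0.5 * x0       # ideal v target

    eps_from_v = (alpha_prod_t**0.5) * v + (beta_prod_t**0.5) * sample
    assert torch.allclose(eps_from_v, eps, atol=1e-6)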
