Commit ecadcde

[Bug] scheduling_ddpm: fix variance in the case of learned_range type. (huggingface#2090)
scheduling_ddpm: fix variance in the case of the learned_range variance type. There are missing logs and a missing exponent compared to the theory (see "Improved Denoising Diffusion Probabilistic Models", Section 3.1, Equation 15: https://arxiv.org/pdf/2102.09672.pdf).
1 parent 2bbd532 commit ecadcde

File tree

1 file changed (+5 −2 lines)


src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 5 additions & 2 deletions
@@ -218,8 +218,8 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
         elif variance_type == "learned":
             return predicted_variance
         elif variance_type == "learned_range":
-            min_log = variance
-            max_log = self.betas[t]
+            min_log = torch.log(variance)
+            max_log = torch.log(self.betas[t])
             frac = (predicted_variance + 1) / 2
             variance = frac * max_log + (1 - frac) * min_log

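The first hunk interpolates between the logs of the two variance bounds rather than the raw values, matching Equation 15 of the Improved DDPM paper: Σθ = exp(v·log βt + (1 − v)·log β̃t). A minimal standalone sketch of that corrected interpolation (the function name and argument names here are illustrative, not diffusers' actual signature):

```python
import torch

def interpolate_log_variance(predicted_variance, beta_t, posterior_variance):
    # Log-space interpolation between the two variance bounds (Improved DDPM, Eq. 15).
    min_log = torch.log(posterior_variance)  # lower bound: the posterior variance beta-tilde_t
    max_log = torch.log(beta_t)              # upper bound: beta_t
    frac = (predicted_variance + 1) / 2      # map model output from [-1, 1] to [0, 1]
    return frac * max_log + (1 - frac) * min_log
```

At the endpoints, a predicted value of −1 recovers log β̃t and +1 recovers log βt, so the model output smoothly selects a log-variance between the two theoretical bounds.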
@@ -304,6 +304,9 @@ def step(
             )
             if self.variance_type == "fixed_small_log":
                 variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
+            elif self.variance_type == "learned_range":
+                variance = self._get_variance(t, predicted_variance=predicted_variance)
+                variance = torch.exp(0.5 * variance) * variance_noise
             else:
                 variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

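Because `_get_variance` now returns a log-variance in the `learned_range` case, `step` must convert it to a standard deviation before scaling the noise: exp(0.5 · log σ²) = σ. A small sketch of that conversion (a helper name of my own, not part of diffusers):

```python
import torch

def noise_std_from_log_variance(log_variance):
    # exp(0.5 * log(sigma^2)) == sigma, i.e. the standard deviation,
    # which is what the sampled Gaussian noise must be scaled by.
    return torch.exp(0.5 * log_variance)
```

This is the counterpart of the `** 0.5` square root applied in the `else` branch, where `_get_variance` returns a plain (non-log) variance.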