Skip to content

Commit ca68ab3

Browse files
Update scheduling_repaint.py (huggingface#1582)
* Update scheduling_repaint.py * update the expected image Co-authored-by: anton- <[email protected]>
1 parent ced7c96 commit ca68ab3

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

src/diffusers/schedulers/scheduling_repaint.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def step(
287287
prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
288288

289289
# 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
290-
prev_known_part = (alpha_prod_t**0.5) * original_image + ((1 - alpha_prod_t) ** 0.5) * noise
290+
prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
291291

292292
# 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
293293
pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part

tests/pipelines/repaint/test_repaint.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch
2020

2121
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
22-
from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device
22+
from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
2323

2424

2525
torch.backends.cuda.matmul.allow_tf32 = False
@@ -36,11 +36,10 @@ def test_celebahq(self):
3636
mask_image = load_image(
3737
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
3838
)
39-
expected_image = load_image(
39+
expected_image = load_numpy(
4040
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
41-
"repaint/celeba_hq_256_result.png"
41+
"repaint/celeba_hq_256_result.npy"
4242
)
43-
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
4443

4544
model_id = "google/ddpm-ema-celebahq-256"
4645
unet = UNet2DModel.from_pretrained(model_id)

0 commit comments

Comments
 (0)