
Commit d1efefe

[Breaking change] fix legacy inpaint noise and resize mask tensor (huggingface#2147)
* fix legacy inpaint noise and resize mask tensor
* updated legacy inpaint pipe test expected_slice
1 parent 7d96b38 commit d1efefe

File tree

2 files changed (+37, -17 lines)


src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 34 additions & 14 deletions
@@ -45,16 +45,34 @@ def preprocess_image(image):
 
 
 def preprocess_mask(mask, scale_factor=8):
-    mask = mask.convert("L")
-    w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
-    mask = torch.from_numpy(mask)
-    return mask
+
+    if not isinstance(mask, torch.FloatTensor):
+        mask = mask.convert("L")
+        w, h = mask.size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+        mask = np.array(mask).astype(np.float32) / 255.0
+        mask = np.tile(mask, (4, 1, 1))
+        mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+        mask = 1 - mask  # repaint white, keep black
+        mask = torch.from_numpy(mask)
+        return mask
+
+    else:
+        valid_mask_channel_sizes = [1, 3]
+        # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
+        if mask.shape[3] in valid_mask_channel_sizes:
+            mask = mask.permute(0, 3, 1, 2)
+        elif mask.shape[1] not in valid_mask_channel_sizes:
+            raise ValueError(
+                f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension, but received mask of shape {tuple(mask.shape)}"
+            )
+        # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
+        mask = mask.mean(dim=1, keepdim=True)
+        h, w = mask.shape[-2:]
+        h, w = map(lambda x: x - x % 32, (h, w))  # resize to integer multiple of 32
+        mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
+        return mask
 
 
 class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
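
For a tensor input, the new branch first normalizes the layout and channel count, then downsamples the mask to latent resolution. A minimal sketch of those steps outside the pipeline, assuming a single 512x512 channels-last mask and the default scale factor of 8 (shapes are illustrative):

import torch

# hypothetical mask in channels-last layout: (B, H, W, C) = (1, 512, 512, 3)
mask = torch.rand(1, 512, 512, 3)

mask = mask.permute(0, 3, 1, 2)           # -> (1, 3, 512, 512), PyTorch standard (B, C, H, W)
mask = mask.mean(dim=1, keepdim=True)     # -> (1, 1, 512, 512), one channel for broadcasting to latents
h, w = mask.shape[-2:]
h, w = map(lambda x: x - x % 32, (h, w))  # round down to an integer multiple of 32
mask = torch.nn.functional.interpolate(mask, (h // 8, w // 8))
print(mask.shape)                         # torch.Size([1, 1, 64, 64]), the latent resolution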
@@ -497,8 +515,8 @@ def __call__(
             mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
-                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
-                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+                PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the
+                expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                 is 1, the denoising process will be run on the masked area for the full number of iterations specified
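
Per the updated docstring, a tensor mask may be supplied in either layout as long as the channel dimension is 1 or 3. A quick sketch of shapes that the new preprocess_mask accepts or rejects (sizes invented for illustration):

import torch

ok_channels_last = torch.rand(1, 64, 64, 1)    # (B, H, W, C) with C = 1 -> permuted to (B, C, H, W)
ok_channels_first = torch.rand(1, 3, 64, 64)   # (B, C, H, W) with C = 3 -> averaged down to one channel
bad_mask = torch.rand(1, 64, 64, 5)            # neither the second nor the fourth dim is 1 or 3 -> ValueError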
@@ -585,8 +603,7 @@ def __call__(
         if not isinstance(image, torch.FloatTensor):
             image = preprocess_image(image)
 
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
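
Because the isinstance guard is gone, a torch.FloatTensor mask is now also routed through preprocess_mask and therefore resized to the latent resolution rather than used at full size. A hypothetical call, assuming pipe is an already-loaded StableDiffusionInpaintPipelineLegacy and photo.png is a 512x512 input image:

import torch
from PIL import Image

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))  # hypothetical file
mask_tensor = torch.rand(1, 512, 512, 1)  # full-resolution float mask in (B, H, W, C) layout

# __call__ now always runs mask_image through preprocess_mask before the denoising loop
result = pipe(prompt="a red sofa", image=init_image, mask_image=mask_tensor).images[0]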
@@ -640,6 +657,9 @@ def __call__(
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)
 
+        # use original latents corresponding to unmasked portions of the image
+        latents = (init_latents_orig * mask) + (latents * (1 - mask))
+
         # 10. Post-processing
         image = self.decode_latents(latents)
 
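
The new compositing step after the denoising loop pastes the original latents back over the unmasked region one final time. A toy numeric sketch of that blend with invented values:

import torch

init_latents_orig = torch.full((1, 4, 2, 2), 2.0)   # latents encoded from the original image
latents = torch.full((1, 4, 2, 2), -1.0)            # latents produced by the denoising loop
mask = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])   # 1 keeps the original latent, 0 keeps the denoised one

latents = (init_latents_orig * mask) + (latents * (1 - mask))
print(latents[0, 0])
# tensor([[ 2., -1.],
#         [-1.,  2.]])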
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py

Lines changed: 3 additions & 3 deletions
@@ -212,8 +212,8 @@ def test_stable_diffusion_inpaint_legacy(self):
         image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
-
+        expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])
+
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
@@ -260,7 +260,7 @@ def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
+        expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
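
Both updated tests follow the same pattern: compare a 3x3 corner patch of the generated image against hard-coded reference values with a 1e-2 tolerance. A minimal sketch of that assertion with a fabricated array standing in for the pipeline output:

import numpy as np

image = np.random.rand(1, 32, 32, 3)           # stand-in for the pipeline output
image_slice = image[0, -3:, -3:, -1]           # bottom-right 3x3 patch of the last channel
expected_slice = image_slice.flatten() + 5e-3  # fabricated reference values within tolerance

assert image.shape == (1, 32, 32, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2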