Commit 4b45a1e

[docs] Use other checkpoints with inpaint (huggingface#5590)

* tip about inpaint checkpoints
* expand section
* feedback

1 parent f782ca1 commit 4b45a1e

File tree

1 file changed: +177 -45 lines changed

docs/source/en/using-diffusers/inpaint.md

@@ -184,6 +184,183 @@ make_image_grid([init_image, mask_image, image], rows=1, cols=3)
</div>
</div>

## Non-inpaint specific checkpoints

So far, this guide has used inpaint-specific checkpoints such as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). But you can also use regular checkpoints like [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Let's compare the results of the two checkpoints.

The image on the left is generated from a regular checkpoint, and the image on the right is from an inpaint checkpoint. You'll immediately notice the image on the left is not as clean, and you can still see the outline of the area the model is supposed to inpaint. The image on the right is much cleaner, and the inpainted area appears more natural.

<hfoptions id="regular-specific">
<hfoption id="runwayml/stable-diffusion-v1-5">

```py
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid

pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()

# load base and mask image
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")

generator = torch.Generator("cuda").manual_seed(92)
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
make_image_grid([init_image, image], rows=1, cols=2)
```

</hfoption>
<hfoption id="runwayml/stable-diffusion-inpainting">

```py
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid

pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()

# load base and mask image
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")

generator = torch.Generator("cuda").manual_seed(92)
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
make_image_grid([init_image, image], rows=1, cols=2)
```

</hfoption>
</hfoptions>

<div class="flex gap-4">
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-inpaint-specific.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">runwayml/stable-diffusion-v1-5</figcaption>
  </div>
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-specific.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">runwayml/stable-diffusion-inpainting</figcaption>
  </div>
</div>

However, for more basic tasks like erasing an object from an image (like the rocks in the road, for example), a regular checkpoint yields pretty good results. There isn't as noticeable a difference between the regular and inpaint checkpoints.

<hfoptions id="inpaint">
<hfoption id="runwayml/stable-diffusion-v1-5">

```py
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid

pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()

# load base and mask image
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/road-mask.png")

image = pipeline(prompt="road", image=init_image, mask_image=mask_image).images[0]
make_image_grid([init_image, image], rows=1, cols=2)
```

</hfoption>
<hfoption id="runwayml/stable-diffusion-inpainting">

```py
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid

pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
# remove the following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()

# load base and mask image
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/road-mask.png")

image = pipeline(prompt="road", image=init_image, mask_image=mask_image).images[0]
make_image_grid([init_image, image], rows=1, cols=2)
```

</hfoption>
</hfoptions>

<div class="flex gap-4">
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/regular-inpaint-basic.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">runwayml/stable-diffusion-v1-5</figcaption>
  </div>
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/specific-inpaint-basic.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">runwayml/stable-diffusion-inpainting</figcaption>
  </div>
</div>

The trade-off of using a non-inpaint-specific checkpoint is that the overall image quality may be lower, but it generally tends to preserve the mask area (which is why you can still see the mask outline). The inpaint-specific checkpoints are intentionally trained to generate higher quality inpainted images, and that includes creating a more natural transition between the masked and unmasked areas. As a result, these checkpoints are more likely to change your unmasked area.

If preserving the unmasked area is important for your task, you can use the code below to force the unmasked area of an image to remain the same, at the expense of some more unnatural transitions between the masked and unmasked areas.

```py
import PIL
import numpy as np
import torch

from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid

device = "cuda"
pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16,
)
pipeline = pipeline.to(device)

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).resize((512, 512))
mask_image = load_image(mask_url).resize((512, 512))

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
repainted_image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
repainted_image.save("repainted_image.png")

# Convert mask to grayscale NumPy array
mask_image_arr = np.array(mask_image.convert("L"))
# Add a channel dimension to the end of the grayscale mask
mask_image_arr = mask_image_arr[:, :, None]
# Binarize the mask: 1s correspond to the pixels which are repainted
mask_image_arr = mask_image_arr.astype(np.float32) / 255.0
mask_image_arr[mask_image_arr < 0.5] = 0
mask_image_arr[mask_image_arr >= 0.5] = 1

# Take the masked pixels from the repainted image and the unmasked pixels from the initial image
unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
unmasked_unchanged_image.save("force_unmasked_unchanged.png")
make_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)
```

## Configure pipeline parameters

Image features, like quality and "creativity", depend on pipeline parameters. Knowing what these parameters do is important for getting the results you want. Let's take a look at the most important parameters and see how changing them affects the output.
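
As a quick preview, here is a minimal sketch of passing two of these parameters to the pipeline call. It reuses `pipeline`, `init_image`, and `mask_image` from the examples above, and the specific values are assumptions for illustration, not recommendations:

```py
# a sketch of tweaking two common inpainting parameters; the values are illustrative
image = pipeline(
    prompt="concept art digital painting of an elven castle, inspired by lord of the rings",
    image=init_image,
    mask_image=mask_image,
    strength=0.6,        # how much noise is added to the masked area; lower preserves more of the original
    guidance_scale=7.5,  # how strongly generation follows the prompt
).images[0]
make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
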
@@ -309,51 +486,6 @@ make_image_grid([init_image, mask_image, image], rows=1, cols=3)
</figure>
</div>

## Preserve unmasked areas

The [`AutoPipelineForInpainting`] (and other inpainting pipelines) generally changes the unmasked parts of an image to create a more natural transition between the masked and unmasked region. If this behavior is undesirable, you can force the unmasked area to remain the same. However, doing so may result in some unusual transitions between the unmasked and masked areas.

```py
import PIL
import numpy as np
import torch

from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid

device = "cuda"
pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16,
)
pipeline = pipeline.to(device)

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).resize((512, 512))
mask_image = load_image(mask_url).resize((512, 512))

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
repainted_image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
repainted_image.save("repainted_image.png")

# Convert mask to grayscale NumPy array
mask_image_arr = np.array(mask_image.convert("L"))
# Add a channel dimension to the end of the grayscale mask
mask_image_arr = mask_image_arr[:, :, None]
# Binarize the mask: 1s correspond to the pixels which are repainted
mask_image_arr = mask_image_arr.astype(np.float32) / 255.0
mask_image_arr[mask_image_arr < 0.5] = 0
mask_image_arr[mask_image_arr >= 0.5] = 1

# Take the masked pixels from the repainted image and the unmasked pixels from the initial image
unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
unmasked_unchanged_image.save("force_unmasked_unchanged.png")
make_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)
```

## Chained inpainting pipelines

[`AutoPipelineForInpainting`] can be chained with other 🤗 Diffusers pipelines to edit their outputs. This is often useful for improving the output quality of your other diffusion pipelines, and if you're using multiple pipelines, it can be more memory-efficient to chain them together to keep the outputs in latent space and reuse the same pipeline components.
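
For example, here is a minimal sketch of chaining a text-to-image pipeline into an inpainting pipeline. The prompts and mask below are only illustrative assumptions; the component reuse goes through `AutoPipelineForInpainting.from_pipe`, which wraps the already-loaded models instead of reading the checkpoint from disk a second time:

```py
import torch
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid

# generate a base image with a text-to-image pipeline
pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
text2image = pipeline("concept art digital painting of an elven castle, inspired by lord of the rings").images[0]

# reuse the already-loaded components in an inpainting pipeline instead of reloading the checkpoint
pipeline = AutoPipelineForInpainting.from_pipe(pipeline).to("cuda")
mask_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png"
).resize(text2image.size)  # make sure the mask matches the generated image's size
image = pipeline(prompt="digital painting of a fantasy waterfall, cloudy", image=text2image, mask_image=mask_image).images[0]
make_image_grid([text2image, mask_image, image], rows=1, cols=3)
```

Chaining through `from_pipe` keeps memory usage roughly constant because both pipelines share the same UNet, VAE, and text encoder.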
