
Commit 896c98a

Add paint by example (huggingface#1533)
* add paint by example
* mkae loading possibel
* up
* Update src/diffusers/models/attention.py
* up
* finalize weight structure
* make example work
* make it work
* up
* up
* fix
* del
* add
* update
* Apply suggestions from code review
* correct transformer 2d
* finish
* up
* up
* up
* up
* fix
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* Apply suggestions from code review
* up
* finish

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 02d83c9 commit 896c98a

16 files changed, +1121 −34 lines changed

docs/source/_toctree.yml (+2)

@@ -102,6 +102,8 @@
   title: "Latent Diffusion"
 - local: api/pipelines/latent_diffusion_uncond
   title: "Unconditional Latent Diffusion"
+- local: api/pipelines/paint_by_example
+  title: "PaintByExample"
 - local: api/pipelines/pndm
   title: "PNDM"
 - local: api/pipelines/score_sde_ve

docs/source/api/pipelines/overview.mdx (+1)

@@ -53,6 +53,7 @@ available a colab notebook to directly try them out.
 | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Text-to-Image Generation |
 | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Super Resolution Image-to-Image |
 | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
+| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
 | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
 | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
docs/source/api/pipelines/paint_by_example.mdx (new file, +73)

<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# PaintByExample

## Overview

[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen

The abstract of the paper is the following:

*Language-guided image editing has achieved great success recently. In this paper, for the first time, we investigate exemplar-guided image editing for more precise control. We achieve this goal by leveraging self-supervised training to disentangle and re-organize the source image and the exemplar. However, the naive approach will cause obvious fusing artifacts. We carefully analyze it and propose an information bottleneck and strong augmentations to avoid the trivial solution of directly copying and pasting the exemplar image. Meanwhile, to ensure the controllability of the editing process, we design an arbitrary shape mask for the exemplar image and leverage the classifier-free guidance to increase the similarity to the exemplar image. The whole framework involves a single forward of the diffusion model without any iterative optimization. We demonstrate that our method achieves an impressive performance and enables controllable editing on in-the-wild images with high fidelity.*

The original codebase can be found [here](https://github.com/Fantasy-Studio/Paint-by-Example).

## Available Pipelines:

| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_paint_by_example.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py) | *Image-Guided Image Inpainting* | - |

## Tips

- PaintByExample is supported by the official [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example) checkpoint. The checkpoint was warm-started from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) and trained to inpaint partly masked images conditioned on example / reference images.
- To quickly demo *PaintByExample*, please have a look at [this demo](https://huggingface.co/spaces/Fantasy-Studio/Paint-by-Example).
- You can run the following code snippet as an example:

```python
# !pip install diffusers transformers

import PIL
import requests
import torch
from io import BytesIO
from diffusers import DiffusionPipeline


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png"
mask_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png"
example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
example_image = download_image(example_url).resize((512, 512))

pipe = DiffusionPipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0]
image
```

## PaintByExamplePipeline

[[autodoc]] pipelines.paint_by_example.pipeline_paint_by_example.PaintByExamplePipeline
    - __call__
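
As a quick follow-up to the snippet above, the same checkpoint can also be loaded through the dedicated `PaintByExamplePipeline` class exported by this commit. A minimal sketch, reusing `init_image`, `mask_image`, and `example_image` from the snippet (the output filename is only illustrative):

```python
# Minimal sketch: loading via the dedicated pipeline class added in this commit.
import torch
from diffusers import PaintByExamplePipeline

pipe = PaintByExamplePipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example",
    torch_dtype=torch.float16,
).to("cuda")

result = pipe(image=init_image, mask_image=mask_image, example_image=example_image)
result.images[0].save("paint_by_example_output.png")  # illustrative filename
```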

docs/source/index.mdx (+1)

@@ -43,6 +43,7 @@ available a colab notebook to directly try them out.
 | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Text-to-Image Generation |
 | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Super Resolution Image-to-Image |
 | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
+| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
 | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
 | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |

scripts/convert_original_stable_diffusion_to_diffusers.py (+102 −4)

@@ -41,8 +41,9 @@
     UNet2DConditionModel,
 )
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
+from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
 
 
 def shave_segments(path, n_shave_prefix_segments=1):
@@ -647,6 +648,73 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model
 
 
+def convert_paint_by_example_checkpoint(checkpoint):
+    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
+    model = PaintByExampleImageEncoder(config)
+
+    keys = list(checkpoint.keys())
+
+    text_model_dict = {}
+
+    for key in keys:
+        if key.startswith("cond_stage_model.transformer"):
+            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+
+    # load clip vision
+    model.model.load_state_dict(text_model_dict)
+
+    # load mapper
+    keys_mapper = {
+        k[len("cond_stage_model.mapper.res") :]: v
+        for k, v in checkpoint.items()
+        if k.startswith("cond_stage_model.mapper")
+    }
+
+    MAPPING = {
+        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
+        "attn.c_proj": ["attn1.to_out.0"],
+        "ln_1": ["norm1"],
+        "ln_2": ["norm3"],
+        "mlp.c_fc": ["ff.net.0.proj"],
+        "mlp.c_proj": ["ff.net.2"],
+    }
+
+    mapped_weights = {}
+    for key, value in keys_mapper.items():
+        prefix = key[: len("blocks.i")]
+        suffix = key.split(prefix)[-1].split(".")[-1]
+        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
+        mapped_names = MAPPING[name]
+
+        num_splits = len(mapped_names)
+        for i, mapped_name in enumerate(mapped_names):
+            new_name = ".".join([prefix, mapped_name, suffix])
+            shape = value.shape[0] // num_splits
+            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
+
+    model.mapper.load_state_dict(mapped_weights)
+
+    # load final layer norm
+    model.final_layer_norm.load_state_dict(
+        {
+            "bias": checkpoint["cond_stage_model.final_ln.bias"],
+            "weight": checkpoint["cond_stage_model.final_ln.weight"],
+        }
+    )
+
+    # load final proj
+    model.proj_out.load_state_dict(
+        {
+            "bias": checkpoint["proj_out.bias"],
+            "weight": checkpoint["proj_out.weight"],
+        }
+    )
+
+    # load uncond vector
+    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
+    return model
+
+
 def convert_open_clip_checkpoint(checkpoint):
     text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
 
@@ -676,12 +744,24 @@ def convert_open_clip_checkpoint(checkpoint):
         type=str,
         help="The YAML config file corresponding to the original architecture.",
     )
+    parser.add_argument(
+        "--num_in_channels",
+        default=None,
+        type=int,
+        help="The number of input channels. If `None` number of input channels will be automatically inferred.",
+    )
     parser.add_argument(
         "--scheduler_type",
         default="pndm",
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
     )
+    parser.add_argument(
+        "--pipeline_type",
+        default=None,
+        type=str,
+        help="The pipeline type. If `None` pipeline will be automatically inferred.",
+    )
     parser.add_argument(
         "--image_size",
         default=None,
@@ -737,6 +817,9 @@ def convert_open_clip_checkpoint(checkpoint):
 
     original_config = OmegaConf.load(args.original_config_file)
 
+    if args.num_in_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels
+
     if (
         "parameterization" in original_config["model"]["params"]
         and original_config["model"]["params"]["parameterization"] == "v"
@@ -806,8 +889,11 @@ def convert_open_clip_checkpoint(checkpoint):
     vae.load_state_dict(converted_vae_checkpoint)
 
     # Convert the text model.
-    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenOpenCLIPEmbedder":
+    model_type = args.pipeline_type
+    if model_type is None:
+        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+
+    if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
         pipe = StableDiffusionPipeline(
@@ -820,7 +906,19 @@ def convert_open_clip_checkpoint(checkpoint):
             feature_extractor=None,
             requires_safety_checker=False,
         )
-    elif text_model_type == "FrozenCLIPEmbedder":
+    elif model_type == "PaintByExample":
+        vision_model = convert_paint_by_example_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        pipe = PaintByExamplePipeline(
+            vae=vae,
+            image_encoder=vision_model,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=feature_extractor,
+        )
+    elif model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
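
The weight remapping in `convert_paint_by_example_checkpoint` hinges on splitting the mapper's fused `attn.c_qkv` projection into the separate `to_q` / `to_k` / `to_v` tensors that diffusers expects (the `value[i * shape : (i + 1) * shape]` loop above). A minimal, self-contained sketch of that split; the 1024-dimensional shape is illustrative only, not taken from the checkpoint:

```python
# Sketch of the fused-QKV split performed in convert_paint_by_example_checkpoint:
# a single (3 * dim, dim) c_qkv weight is cut into equal chunks, one per target
# name in MAPPING["attn.c_qkv"].
import torch


def split_fused_qkv(c_qkv_weight: torch.Tensor):
    chunk = c_qkv_weight.shape[0] // 3  # num_splits == len(["to_q", "to_k", "to_v"])
    return (
        c_qkv_weight[0 * chunk : 1 * chunk],  # -> attn1.to_q.weight
        c_qkv_weight[1 * chunk : 2 * chunk],  # -> attn1.to_k.weight
        c_qkv_weight[2 * chunk : 3 * chunk],  # -> attn1.to_v.weight
    )


fused = torch.randn(3 * 1024, 1024)  # illustrative shape only
q, k, v = split_fused_qkv(fused)
assert q.shape == k.shape == v.shape == (1024, 1024)
```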

src/diffusers/__init__.py (+1)

@@ -72,6 +72,7 @@
     AltDiffusionPipeline,
     CycleDiffusionPipeline,
     LDMTextToImagePipeline,
+    PaintByExamplePipeline,
     StableDiffusionImageVariationPipeline,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipeline,

src/diffusers/models/attention.py (+56 −24)

@@ -406,6 +406,9 @@ def __init__(
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
+        self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+        # 1. Self-Attn
         self.attn1 = CrossAttention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -415,23 +418,28 @@ def __init__(
             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
-        self.attn2 = CrossAttention(
-            query_dim=dim,
-            cross_attention_dim=cross_attention_dim,
-            heads=num_attention_heads,
-            dim_head=attention_head_dim,
-            dropout=dropout,
-            bias=attention_bias,
-        )  # is self-attn if context is none
 
-        # layer norms
-        self.use_ada_layer_norm = num_embeds_ada_norm is not None
-        if self.use_ada_layer_norm:
-            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
-            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        # 2. Cross-Attn
+        if cross_attention_dim is not None:
+            self.attn2 = CrossAttention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+            )  # is self-attn if context is none
         else:
-            self.norm1 = nn.LayerNorm(dim)
-            self.norm2 = nn.LayerNorm(dim)
+            self.attn2 = None
+
+        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+        if cross_attention_dim is not None:
+            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+        else:
+            self.norm2 = None
+
+        # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim)
 
         # if xformers is installed try to use memory_efficient_attention by default
@@ -481,11 +489,12 @@ def forward(self, hidden_states, context=None, timestep=None):
         else:
             hidden_states = self.attn1(norm_hidden_states) + hidden_states
 
-        # 2. Cross-Attention
-        norm_hidden_states = (
-            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
-        )
-        hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+        if self.attn2 is not None:
+            # 2. Cross-Attention
+            norm_hidden_states = (
+                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+            )
+            hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
 
         # 3. Feed-forward
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -666,14 +675,16 @@ def __init__(
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
 
-        if activation_fn == "geglu":
-            geglu = GEGLU(dim, inner_dim)
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
         elif activation_fn == "geglu-approximate":
-            geglu = ApproximateGELU(dim, inner_dim)
+            act_fn = ApproximateGELU(dim, inner_dim)
 
         self.net = nn.ModuleList([])
         # project in
-        self.net.append(geglu)
+        self.net.append(act_fn)
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
@@ -685,6 +696,27 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class GELU(nn.Module):
+    r"""
+    GELU activation function
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out)
+
+    def gelu(self, gate):
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        return hidden_states
+
+
 # feedforward
 class GEGLU(nn.Module):
     r"""

src/diffusers/pipelines/__init__.py (+1)

@@ -28,6 +28,7 @@
 if is_torch_available() and is_transformers_available():
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .latent_diffusion import LDMTextToImagePipeline
+    from .paint_by_example import PaintByExamplePipeline
     from .stable_diffusion import (
         CycleDiffusionPipeline,
         StableDiffusionImageVariationPipeline,
