
Commit 38466c3

Add GLIGEN Text Image implementation (huggingface#4777)
* Add GLIGEN Text Image implementation
* add style transfer from image
* fix check_repository_consistency
* add convert script GLIGEN model to Diffusers
* rename attention type
* fix style code
* remove PositionNetTextImage
* Revert "fix check_repository_consistency"
  This reverts commit 15f098c.
* change attention type name
* update docs for GLIGEN
* change examples with hf-document-image
* fix style
* add CLIPImageProjection for GLIGEN
* Add new encode_prompt, load project matrix in pipe init
* move CLIPImageProjection to stable_diffusion
* add comment
1 parent 5f740d0 commit 38466c3

File tree

14 files changed: +2014 −39 lines

docs/source/en/api/pipelines/stable_diffusion/gligen.md

Lines changed: 15 additions & 2 deletions
```diff
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
 
 # GLIGEN (Grounded Language-to-Image Generation)
 
-The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes, if input images are given, this pipeline can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
+The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
 
 The abstract from the [paper](https://huggingface.co/papers/2301.07093) is:
 
@@ -26,7 +26,7 @@ If you want to use one of the official checkpoints for a task, explore the [glig
 
 </Tip>
 
-This pipeline was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful).
+[`StableDiffusionGLIGENPipeline`] was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful) and [`StableDiffusionGLIGENTextImagePipeline`] was contributed by [Nguyễn Công Tú Anh](https://github.com/tuanh123789).
 
 ## StableDiffusionGLIGENPipeline
 
@@ -41,6 +41,19 @@ This pipeline was contributed by [Nikhil Gajendrakumar](https://github.com/nikhi
   - prepare_latents
   - enable_fuser
 
+## StableDiffusionGLIGENTextImagePipeline
+
+[[autodoc]] StableDiffusionGLIGENTextImagePipeline
+  - all
+  - __call__
+  - enable_vae_slicing
+  - disable_vae_slicing
+  - enable_vae_tiling
+  - disable_vae_tiling
+  - enable_model_cpu_offload
+  - prepare_latents
+  - enable_fuser
+
 ## StableDiffusionPipelineOutput
 
 [[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
```
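The docs describe the new pipeline only at a high level. Below is a hedged usage sketch, not code taken from this commit: the checkpoint id and the reference-image path are placeholders, and the `gligen_*` argument names are assumed to follow the existing [`StableDiffusionGLIGENPipeline`] convention.

```python
import torch
from PIL import Image
from diffusers import StableDiffusionGLIGENTextImagePipeline

# Assumed checkpoint id; any GLIGEN text-image checkpoint converted with
# scripts/convert_gligen_to_diffusers.py should work here.
pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained(
    "anhnct/Gligen_Text_Image", torch_dtype=torch.float16
).to("cuda")

# Reference image of the object to insert (placeholder path).
flower_image = Image.open("flower_reference.png")

images = pipe(
    prompt="a flower sitting on the beach",
    gligen_phrases=["flower"],               # text description of the grounded object
    gligen_images=[flower_image],            # image description of the same object
    gligen_boxes=[[0.0, 0.09, 0.53, 0.76]],  # normalized xyxy box for that object
    gligen_scheduled_sampling_beta=1,
    num_inference_steps=50,
).images

images[0].save("gligen_text_image_output.png")
```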

scripts/convert_gligen_to_diffusers.py

Lines changed: 587 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -66,6 +66,7 @@
         AutoPipelineForImage2Image,
         AutoPipelineForInpainting,
         AutoPipelineForText2Image,
+        CLIPImageProjection,
         ConsistencyModelPipeline,
         DanceDiffusionPipeline,
         DDIMPipeline,
@@ -176,6 +177,7 @@
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
         StableDiffusionGLIGENPipeline,
+        StableDiffusionGLIGENTextImagePipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,
         StableDiffusionInpaintPipeline,
```

src/diffusers/models/attention.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -154,7 +154,7 @@ def __init__(
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
         # 4. Fuser
-        if attention_type == "gated":
+        if attention_type == "gated" or attention_type == "gated-text-image":
             self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
 
         # let chunk size default to None
```
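For context, both the `gated` and the new `gated-text-image` attention types route grounding tokens through `GatedSelfAttentionDense`. The snippet below is a minimal re-implementation of that gated self-attention idea from the GLIGEN paper, not the diffusers class itself; the class name and layer layout are illustrative.

```python
import torch
from torch import nn


class GatedSelfAttentionSketch(nn.Module):
    """Illustrative sketch: visual tokens attend jointly to themselves and to the
    grounding tokens produced by PositionNet, and the result is blended back through
    a learnable tanh gate that starts at zero, so the pretrained UNet behavior is
    unchanged until the gate is trained open."""

    def __init__(self, dim: int, context_dim: int, n_heads: int):
        super().__init__()
        self.linear = nn.Linear(context_dim, dim)  # project grounding tokens to the visual width
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.alpha = nn.Parameter(torch.zeros(1))  # gate, initialized closed

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        n_visual = x.shape[1]
        h = self.norm(torch.cat([x, self.linear(objs)], dim=1))
        h, _ = self.attn(h, h, h)
        return x + torch.tanh(self.alpha) * h[:, :n_visual]
```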

src/diffusers/models/embeddings.py

Lines changed: 66 additions & 13 deletions
```diff
@@ -563,7 +563,7 @@ def __call__(self, x):
 
 
 class PositionNet(nn.Module):
-    def __init__(self, positive_len, out_dim, fourier_freqs=8):
+    def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
         self.out_dim = out_dim
@@ -573,30 +573,83 @@ def __init__(self, positive_len, out_dim, fourier_freqs=8):
 
         if isinstance(out_dim, tuple):
             out_dim = out_dim[0]
-        self.linears = nn.Sequential(
-            nn.Linear(self.positive_len + self.position_dim, 512),
-            nn.SiLU(),
-            nn.Linear(512, 512),
-            nn.SiLU(),
-            nn.Linear(512, out_dim),
-        )
 
-        self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+        if feature_type == "text-only":
+            self.linears = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+        elif feature_type == "text-image":
+            self.linears_text = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.linears_image = nn.Sequential(
+                nn.Linear(self.positive_len + self.position_dim, 512),
+                nn.SiLU(),
+                nn.Linear(512, 512),
+                nn.SiLU(),
+                nn.Linear(512, out_dim),
+            )
+            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
         self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
 
-    def forward(self, boxes, masks, positive_embeddings):
+    def forward(
+        self,
+        boxes,
+        masks,
+        positive_embeddings=None,
+        phrases_masks=None,
+        image_masks=None,
+        phrases_embeddings=None,
+        image_embeddings=None,
+    ):
         masks = masks.unsqueeze(-1)
 
         # embedding position (it may includes padding as placeholder)
         xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
 
         # learnable null embedding
-        positive_null = self.null_positive_feature.view(1, 1, -1)
         xyxy_null = self.null_position_feature.view(1, 1, -1)
 
         # replace padding with learnable null embedding
-        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
         xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
 
-        objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+        # positionet with text only information
+        if positive_embeddings is not None:
+            # learnable null embedding
+            positive_null = self.null_positive_feature.view(1, 1, -1)
+
+            # replace padding with learnable null embedding
+            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+
+            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+
+        # positionet with text and image infomation
+        else:
+            phrases_masks = phrases_masks.unsqueeze(-1)
+            image_masks = image_masks.unsqueeze(-1)
+
+            # learnable null embedding
+            text_null = self.null_text_feature.view(1, 1, -1)
+            image_null = self.null_image_feature.view(1, 1, -1)
+
+            # replace padding with learnable null embedding
+            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
+            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
+
+            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
+            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
+            objs = torch.cat([objs_text, objs_image], dim=1)
+
         return objs
```
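A quick sanity-check sketch of the two `PositionNet` modes introduced above; the shapes and values are illustrative, and the import path assumes the module layout shown in this diff.

```python
import torch
from diffusers.models.embeddings import PositionNet

boxes = torch.rand(2, 30, 4)                  # batch of 30 normalized xyxy boxes
masks = torch.randint(0, 2, (2, 30)).float()  # 1 = real box, 0 = padding
text_emb = torch.rand(2, 30, 768)             # per-box CLIP phrase embeddings

# text-only grounding tokens (used with attention_type="gated")
net_text = PositionNet(positive_len=768, out_dim=768, feature_type="text-only")
objs = net_text(boxes, masks, positive_embeddings=text_emb)
print(objs.shape)  # torch.Size([2, 30, 768])

# text + image grounding tokens (used with attention_type="gated-text-image")
net_text_image = PositionNet(positive_len=768, out_dim=768, feature_type="text-image")
objs = net_text_image(
    boxes,
    masks,
    phrases_masks=masks,
    image_masks=masks,
    phrases_embeddings=text_emb,
    image_embeddings=torch.rand(2, 30, 768),
)
print(objs.shape)  # torch.Size([2, 60, 768]); text and image tokens are concatenated
```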

src/diffusers/models/unet_2d_condition.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -565,13 +565,17 @@ def __init__(
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
-        if attention_type == "gated":
+        if attention_type in ["gated", "gated-text-image"]:
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
             elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
                 positive_len = cross_attention_dim[0]
-            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
+
+            feature_type = "text-only" if attention_type == "gated" else "text-image"
+            self.position_net = PositionNet(
+                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+            )
 
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
```
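With this change, which grounding branch a `UNet2DConditionModel` builds is selected purely by its `attention_type` config value. A small construction sketch under that assumption; the block sizes are illustrative, not the SD 1.x layout real GLIGEN checkpoints use.

```python
from diffusers import UNet2DConditionModel

# Tiny UNet just to show the wiring between attention_type and PositionNet.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_type="gated-text-image",  # "gated" would build the text-only PositionNet
)
print(type(unet.position_net).__name__)  # PositionNet, built with feature_type="text-image"
```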

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -94,6 +94,7 @@
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
         StableDiffusionGLIGENPipeline,
+        StableDiffusionGLIGENTextImagePipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,
         StableDiffusionInpaintPipeline,
@@ -111,6 +112,7 @@
         StableUnCLIPImg2ImgPipeline,
         StableUnCLIPPipeline,
     )
+    from .stable_diffusion.clip_image_project_model import CLIPImageProjection
     from .stable_diffusion_safe import StableDiffusionPipelineSafe
     from .stable_diffusion_xl import (
         StableDiffusionXLImg2ImgPipeline,
```

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -42,10 +42,12 @@ class StableDiffusionPipelineOutput(BaseOutput):
 except OptionalDependencyNotAvailable:
     from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
+    from .clip_image_project_model import CLIPImageProjection
     from .pipeline_cycle_diffusion import CycleDiffusionPipeline
     from .pipeline_stable_diffusion import StableDiffusionPipeline
     from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
     from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
+    from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
     from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
```
src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py

Lines changed: 29 additions & 0 deletions

```diff
@@ -0,0 +1,29 @@
+# Copyright 2023 The GLIGEN Authors and HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+
+
+class CLIPImageProjection(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(self, hidden_size: int = 768):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+    def forward(self, x):
+        return self.project(x)
```
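A minimal sketch of what this new module does: a single bias-free linear map over CLIP image embeddings (768-dim for the ViT-L/14 encoder GLIGEN uses). How the pipeline loads the trained projection weights is handled elsewhere in this PR and is not shown here.

```python
import torch
from diffusers import CLIPImageProjection

image_project = CLIPImageProjection(hidden_size=768)

# e.g. pooled CLIP image embeddings for two reference images (random stand-ins here)
image_embeds = torch.randn(2, 768)
projected = image_project(image_embeds)
print(projected.shape)  # torch.Size([2, 768])
```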
