Commit 22a3176

patrickvonplaten, williamberman, yiyixuxu, and pcuenca authored
[Docs] Weight prompting using compel (huggingface#2574)
* add docs
* correct
* finish
* Apply suggestions from code review (co-authored by Will Berman and YiYi Xu)
* update deps table
* Apply suggestions from code review (co-authored by Pedro Cuenca)

Co-authored-by: Will Berman <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent f0b661b commit 22a3176

File tree

8 files changed: +171 −2 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@
     title: How to contribute a Pipeline
   - local: using-diffusers/using_safetensors
     title: Using safetensors
+  - local: using-diffusers/weighted_prompts
+    title: Weighting Prompts
   title: Pipelines for Inference
 - sections:
   - local: using-diffusers/rl

docs/source/en/using-diffusers/controlling_generation.mdx

Lines changed: 7 additions & 0 deletions
@@ -36,6 +36,7 @@ Unless otherwise mentioned, these are techniques that work with existing models
 8. [DreamBooth](#dreambooth)
 9. [Textual Inversion](#textual-inversion)
 10. [ControlNet](#controlnet)
+11. [Prompt Weighting](#prompt-weighting)

 ## Instruct Pix2Pix

@@ -158,3 +159,9 @@ depth maps, and semantic segmentations.

 See [here](../api/pipelines/stable_diffusion/controlnet) for more information on how to use it.

+## Prompt Weighting
+
+Prompt weighting is a simple technique that puts more attention weight on certain parts of the text
+input.
+
+For a more detailed explanation and examples, see [here](../using-diffusers/weighted_prompts).
docs/source/en/using-diffusers/weighted_prompts.mdx

Lines changed: 98 additions & 0 deletions
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Weighting prompts

Text-guided diffusion models generate images based on a given text prompt. The text prompt
can include multiple concepts that the model should generate, and it's often desirable to weight
certain parts of the prompt more or less.

Diffusion models work by conditioning the cross-attention layers of the diffusion model with contextualized text embeddings (see the [Stable Diffusion Guide](../stable-diffusion) for more information).
Thus, a simple way to emphasize (or de-emphasize) certain parts of the prompt is to increase or reduce the scale of the text embedding vectors that correspond to the relevant parts of the prompt.
This is called "prompt weighting" and has been a highly requested feature in the community (see [this issue](https://github.com/huggingface/diffusers/issues/2431)).

## How to do prompt weighting in Diffusers

We believe the role of `diffusers` is to be a toolbox that provides essential features enabling other projects, such as [InvokeAI](https://github.com/invoke-ai/InvokeAI) or [diffuzers](https://github.com/abhishekkrthakur/diffuzers), to build powerful UIs. In order to support arbitrary ways of manipulating prompts, `diffusers` exposes a [`prompt_embeds`](https://huggingface.co/docs/diffusers/v0.14.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) argument in many pipelines, such as [`StableDiffusionPipeline`], allowing you to pass the "prompt-weighted"/scaled text embeddings to the pipeline directly.
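To make this concrete, below is a minimal sketch (not taken from the commit itself) of what preparing such embeddings by hand could look like: encode the prompt with the pipeline's own tokenizer and text encoder, rescale the embedding vectors at the token positions you care about, and pass the result as `prompt_embeds`. The `1.5` multiplier is an arbitrary illustrative value.

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

prompt = "a red cat playing with a ball"

# Tokenize and encode the prompt the same way the pipeline would internally.
text_inputs = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids)[0]

# Scale up the embedding vectors at the token positions that belong to "ball".
ball_ids = pipe.tokenizer("ball", add_special_tokens=False).input_ids
mask = torch.isin(text_inputs.input_ids[0], torch.tensor(ball_ids))
prompt_embeds[0, mask] *= 1.5  # illustrative weight

image = pipe(prompt_embeds=prompt_embeds).images[0]
```

Doing this correctly for arbitrary prompts (sub-word tokens, renormalization, long prompts) gets fiddly quickly, which is why we recommend the library below.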
The [compel library](https://github.com/damian0815/compel) provides an easy way to emphasize or de-emphasize portions of the prompt for you. We strongly recommend it instead of preparing the embeddings yourself.

Let's look at a simple example. Imagine you want to generate an image of `"a red cat playing with a ball"` as
follows:

```py
import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

prompt = "a red cat playing with a ball"

generator = torch.Generator(device="cpu").manual_seed(33)

image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
image
```
This gives you:

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_0.png)

As you can see, there is no "ball" in the image. Let's emphasize this part!

For this we should install the `compel` library:

```
pip install compel
```

and then create a `Compel` object:

```py
from compel import Compel

compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
```

Now we emphasize the part "ball" with the `"++"` syntax:

```py
prompt = "a red cat playing with a ball++"
```

and instead of passing this to the pipeline directly, we have to process it using `compel_proc`:

```py
prompt_embeds = compel_proc(prompt)
```

Now we can pass `prompt_embeds` directly to the pipeline:

```py
generator = torch.Generator(device="cpu").manual_seed(33)

image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
image
```

We now get the following image, which does have a "ball"!

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_1.png)

Similarly, you can de-emphasize parts of the sentence by using the `--` suffix for words; feel free to give it
a try! A quick variation is sketched below.
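For instance, a minimal variation (not part of the original example; the resulting image is not shown here) that lowers the weight of "ball", reusing the `pipe` and `compel_proc` objects from above:

```py
# Lower the weight of "ball" with the `--` suffix and process the prompt with compel.
prompt = "a red cat playing with a ball--"
prompt_embeds = compel_proc(prompt)

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
image
```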
If your favorite pipeline does not have a `prompt_embeds` input, please make sure to open an issue; the
diffusers team tries to be as responsive as possible.

Also, please check out the documentation of the [compel](https://github.com/damian0815/compel) library for
more information.

setup.py

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,7 @@
 _deps = [
     "Pillow",  # keep the PIL.Image.Resampling deprecation away
     "accelerate>=0.11.0",
+    "compel==0.1.8",
     "black~=23.1",
     "datasets",
     "filelock",

@@ -182,6 +183,7 @@ def run(self):
 extras["docs"] = deps_list("hf-doc-builder")
 extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2")
 extras["test"] = deps_list(
+    "compel",
     "datasets",
     "Jinja2",
     "k-diffusion",

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 deps = {
     "Pillow": "Pillow",
     "accelerate": "accelerate>=0.11.0",
+    "compel": "compel==0.1.8",
     "black": "black~=23.1",
     "datasets": "datasets",
     "filelock": "filelock",

src/diffusers/utils/import_utils.py

Lines changed: 19 additions & 0 deletions
@@ -232,6 +232,14 @@
 _tensorboard_available = False


+_compel_available = importlib.util.find_spec("compel")
+try:
+    _compel_version = importlib_metadata.version("compel")
+    logger.debug(f"Successfully imported compel version {_compel_version}")
+except importlib_metadata.PackageNotFoundError:
+    _compel_available = False
+
+
 def is_torch_available():
     return _torch_available


@@ -296,6 +304,10 @@ def is_tensorboard_available():
     return _tensorboard_available


+def is_compel_available():
+    return _compel_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the

@@ -368,6 +380,12 @@ def is_tensorboard_available():
 install tensorboard`
 """

+
+# docstyle-ignore
+COMPEL_IMPORT_ERROR = """
+{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),

@@ -382,6 +400,7 @@ def is_tensorboard_available():
         ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
         ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
         ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
+        ("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
     ]
 )
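For context, a minimal sketch (not part of this diff) of how such an availability flag is typically consumed, guarding an optional import so that `diffusers` itself never hard-depends on `compel`:

```py
# Illustrative only: guard an optional compel import behind the new availability flag.
from diffusers.utils.import_utils import is_compel_available

if is_compel_available():
    from compel import Compel
else:
    Compel = None  # callers that need compel can surface COMPEL_IMPORT_ERROR instead
```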

src/diffusers/utils/testing_utils.py

Lines changed: 9 additions & 1 deletion
@@ -16,7 +16,7 @@
 import requests
 from packaging import version

-from .import_utils import is_flax_available, is_onnx_available, is_torch_available
+from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available
 from .logging import get_logger


@@ -175,6 +175,14 @@ def require_flax(test_case):
     return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


+def require_compel(test_case):
+    """
+    Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
+    the library is not installed.
+    """
+    return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
+
+
 def require_onnxruntime(test_case):
     """
     Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.

tests/test_pipelines.py

Lines changed: 33 additions & 1 deletion
@@ -49,11 +49,12 @@
     StableDiffusionPipeline,
     UNet2DConditionModel,
     UNet2DModel,
+    UniPCMultistepScheduler,
     logging,
 )
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu


 torch.backends.cuda.matmul.allow_tf32 = False

@@ -1058,6 +1059,37 @@ def test_from_flax_from_pt(self):

         assert np.abs(image_0 - image_1).sum() < 1e-5, "Models don't give the same forward pass"

+    @require_compel
+    def test_weighted_prompts_compel(self):
+        from compel import Compel
+
+        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+        pipe.enable_model_cpu_offload()
+        pipe.enable_attention_slicing()
+
+        compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
+
+        prompt = "a red cat playing with a ball{}"
+
+        prompts = [prompt.format(s) for s in ["", "++", "--"]]
+
+        prompt_embeds = compel(prompts)
+
+        generator = [torch.Generator(device="cpu").manual_seed(33) for _ in range(prompt_embeds.shape[0])]
+
+        images = pipe(
+            prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20, output_type="numpy"
+        ).images
+
+        for i, image in enumerate(images):
+            expected_image = load_numpy(
+                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+                f"/compel/forest_{i}.npy"
+            )
+
+            assert np.abs(image - expected_image).max() < 1e-3
+

 @nightly
 @require_torch_gpu
