
Commit f6df16c

[docs] Community tips (huggingface#7137)
* tips
* feedback
* callback only

1 file changed: +112 −31 lines

docs/source/en/using-diffusers/callback.md

# Pipeline callbacks

The denoising loop of a pipeline can be modified with custom-defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
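
To get a feel for the shape of a callback before diving into concrete features, here is a minimal no-op sketch (the function and argument names are illustrative): a callback receives the pipeline, the current step index, the timestep, and a dict of tensor variables, and it must return that dict.

```py
# minimal no-op callback sketch; the names here are illustrative
def my_callback(pipeline, step_index, timestep, callback_kwargs):
    # inspect or modify pipeline attributes and tensor variables here
    return callback_kwargs  # the (possibly modified) dict must be returned
```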

> [!TIP]
> 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!

This guide will demonstrate how callbacks work by walking you through a few features you can implement with them.

## Dynamic classifier-free guidance

Dynamic classifier-free guidance (CFG) is a feature that allows you to disable CFG after a certain number of inference steps, which can help you save compute with minimal cost to performance. The callback function for this should have the following arguments:

* `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
* `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify (see the snippet after this list). Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly; when CFG is active, the embeddings hold both the negative and positive prompt halves, so only the conditional half is needed once guidance is disabled.
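
As a quick check, you can print a pipeline's `_callback_tensor_inputs` attribute to see which tensor variables it supports. The sketch below uses `StableDiffusionPipeline`; the printed list is an example, as it varies per pipeline.

```py
from diffusers import StableDiffusionPipeline

# class-level list of tensors a callback is allowed to modify
print(StableDiffusionPipeline._callback_tensor_inputs)
# e.g. ['latents', 'prompt_embeds', 'negative_prompt_embeds']
```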

Your callback function should look something like this:

```python
def callback_dynamic_cfg(pipeline, step_index, timestep, callback_kwargs):
    # adjust the batch_size of prompt_embeds according to guidance_scale
    if step_index == int(pipeline.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        # CFG runs a doubled batch (negative + positive); keep the conditional half
        prompt_embeds = prompt_embeds.chunk(2)[-1]

        # update guidance_scale and prompt_embeds
        pipeline._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs
```
Now, you can pass the callback function to the `callback_on_step_end` parameter:

```python
import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(1)
out = pipeline(
    prompt,
    generator=generator,
    callback_on_step_end=callback_dynamic_cfg,
    callback_on_step_end_tensor_inputs=['prompt_embeds']
)

out.images[0].save("out_custom_cfg.png")
```
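
The same mechanism supports schedules other than switching CFG off entirely. As a sketch (the schedule and values below are illustrative, not an official recipe), a callback could gradually anneal the guidance scale instead; keeping the scale above 1.0 means CFG stays active, so no batch-size adjustment is needed.

```py
def callback_decay_cfg(pipeline, step_index, timestep, callback_kwargs):
    # illustrative schedule: linearly anneal guidance from 7.5 down to 2.0
    start_scale, end_scale = 7.5, 2.0
    progress = (step_index + 1) / pipeline.num_timesteps
    pipeline._guidance_scale = start_scale + (end_scale - start_scale) * progress
    return callback_kwargs
```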

## Interrupt the diffusion process

> [!TIP]
> The interruption callback is supported for text-to-image, image-to-image, and inpainting for the [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview) and [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl).

Stopping the diffusion process early is useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

This callback function should take the following arguments: `pipeline`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback; a sketch of a time-based variant follows the example below.

In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.

```python
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.enable_model_cpu_offload()
num_inference_steps = 50

def interrupt_callback(pipeline, i, t, callback_kwargs):
    # stop the diffusion process after `stop_idx` steps
    stop_idx = 10
    if i == stop_idx:
        pipeline._interrupt = True

    return callback_kwargs

pipeline(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)
```
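
Since any condition can set `_interrupt`, you aren't limited to a fixed step count. As a sketch of custom stopping logic (the time budget below is illustrative), a callback could interrupt generation once a wall-clock limit is exceeded:

```py
import time

start = time.time()
time_limit = 5.0  # illustrative budget in seconds

def timed_interrupt_callback(pipeline, i, t, callback_kwargs):
    # stop denoising once the wall-clock budget is exhausted
    if time.time() - start > time_limit:
        pipeline._interrupt = True
    return callback_kwargs
```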

## Display image after each generation step

> [!TIP]
> This tip was contributed by [asomoza](https://github.com/asomoza).

Display an image after each generation step by accessing and converting the latents into an image after each step. The latent space is compressed to 128x128, so the images are also 128x128, which is useful for a quick preview.

1. Use the function below to convert the SDXL latents (4 channels) to RGB tensors (3 channels), as explained in the [Explaining the SDXL latent space](https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space) blog post (a standalone sanity check of this function appears after these steps).

```py
def latents_to_rgb(latents):
    # fixed linear projection from the 4 latent channels to RGB
    weights = (
        (60, -60, 25, -70),
        (60, -5, 15, -50),
        (60, 10, -5, -35)
    )

    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)

    return Image.fromarray(image_array)
```

2. Create a function to decode and save the latents into an image.

```py
def decode_tensors(pipeline, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]

    # convert the latents to a small RGB preview and save it
    image = latents_to_rgb(latents)
    image.save(f"{step}.png")

    return callback_kwargs
```

3. Pass the `decode_tensors` function to the `callback_on_step_end` parameter to decode the tensors after each step. You also need to specify what you want to modify in the `callback_on_step_end_tensor_inputs` parameter, which in this case are the latents.

```py
from diffusers import AutoPipelineForText2Image
import torch
from PIL import Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True
).to("cuda")

image = pipeline(
    prompt="A croissant shaped like a cute bear.",
    negative_prompt="Deformed, ugly, bad anatomy",
    callback_on_step_end=decode_tensors,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```
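
To sanity-check `latents_to_rgb` on its own before running a full pipeline, you can feed it randomly generated latents; the shape below assumes SDXL's 4-channel, 128x128 latent layout.

```py
import torch
from PIL import Image

# dummy SDXL-shaped latents: batch of 1, 4 channels, 128x128
dummy_latents = torch.randn(1, 4, 128, 128)
latents_to_rgb(dummy_latents).save("preview_test.png")
```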

<div class="flex gap-4 justify-center">
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_0.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">step 0</figcaption>
  </div>
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_19.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">step 19</figcaption>
  </div>
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_29.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">step 29</figcaption>
  </div>
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_39.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">step 39</figcaption>
  </div>
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_49.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">step 49</figcaption>
  </div>
</div>
