Commit fdb05f5

Official callbacks (huggingface#7761)
1 parent 98ba18b commit fdb05f5

17 files changed: +400 -94 lines changed

docs/source/en/using-diffusers/callback.md

Lines changed: 64 additions & 3 deletions
@@ -19,13 +19,74 @@ The denoising loop of a pipeline can be modified with custom defined functions u
 
 This guide will demonstrate how callbacks work through a few features you can implement with them.
 
+## Official callbacks
+
+We provide a list of callbacks you can plug into an existing pipeline to modify the denoising loop. This is the current list of official callbacks:
+
+- `SDCFGCutoffCallback`: Disables CFG after a certain number of steps for all SD 1.5 pipelines, including text-to-image, image-to-image, inpaint, and controlnet.
+- `SDXLCFGCutoffCallback`: Disables CFG after a certain number of steps for all SDXL pipelines, including text-to-image, image-to-image, inpaint, and controlnet.
+- `IPAdapterScaleCutoffCallback`: Disables the IP-Adapter after a certain number of steps for all pipelines supporting IP-Adapter.
+
+> [!TIP]
+> If you want to add a new official callback, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) or [submit a PR](https://huggingface.co/docs/diffusers/main/en/conceptual/contribution#how-to-open-a-pr).
+
+To set up a callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so with one of these two arguments:
+
+- `cutoff_step_ratio`: A float between 0.0 and 1.0 giving the fraction of the total steps after which the callback takes effect.
+- `cutoff_step_index`: An integer giving the exact index of the step at which the callback takes effect.
+
+```python
+import torch
+
+from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
+from diffusers.callbacks import SDXLCFGCutoffCallback
+
+
+callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
+# can also be used with cutoff_step_index
+# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+    variant="fp16",
+).to("cuda")
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)
+
+prompt = "a sports car at the road, best quality, high quality, high detail, 8k resolution"
+
+generator = torch.Generator(device="cpu").manual_seed(2628670641)
+
+out = pipeline(
+    prompt=prompt,
+    negative_prompt="",
+    guidance_scale=6.5,
+    num_inference_steps=25,
+    generator=generator,
+    callback_on_step_end=callback,
+)
+
+out.images[0].save("official_callback.png")
+```
+
+<div class="flex gap-4">
+    <div>
+        <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/without_cfg_callback.png" alt="generated image of a sports car at the road" />
+        <figcaption class="mt-2 text-center text-sm text-gray-500">without SDXLCFGCutoffCallback</figcaption>
+    </div>
+    <div>
+        <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_cfg_callback.png" alt="generated image of a sports car at the road with cfg callback" />
+        <figcaption class="mt-2 text-center text-sm text-gray-500">with SDXLCFGCutoffCallback</figcaption>
+    </div>
+</div>
+
 ## Dynamic classifier-free guidance
 
 Dynamic classifier-free guidance (CFG) is a feature that allows you to disable CFG after a certain number of inference steps, which can help you save compute with minimal cost to performance. The callback function for this should have the following arguments:
 
-* `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
-* `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
-* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
+- `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
+- `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
+- `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
 
 Your callback function should look something like this:
 
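As a rough illustration of the three arguments described in this hunk, here is a minimal sketch of a function-style callback that turns CFG off at 40% of the steps. The name `callback_dynamic_cfg`, the `SimpleNamespace` pipeline, and the list of strings standing in for the `prompt_embeds` tensor are all assumptions made so the example runs without a GPU or model weights. The `[-1:]` slice follows the convention used in `callbacks.py`, where the last batch entry holds the conditional embeddings, and it assumes a single prompt.

```python
from types import SimpleNamespace


def callback_dynamic_cfg(pipeline, step_index, timestep, callback_kwargs):
    """Disable CFG once 40% of the denoising steps have run."""
    if step_index == int(pipeline.num_timesteps * 0.4):
        # With CFG enabled the batch holds [unconditional, conditional] embeddings.
        # Keep only the conditional entry so the batch size matches guidance-free inference.
        prompt_embeds = callback_kwargs["prompt_embeds"]
        callback_kwargs["prompt_embeds"] = prompt_embeds[-1:]
        pipeline._guidance_scale = 0.0
    return callback_kwargs


# Stand-in objects (hypothetical): a real pipeline passes itself and torch tensors.
fake_pipeline = SimpleNamespace(num_timesteps=25, _guidance_scale=6.5)
kwargs = {"prompt_embeds": ["uncond_embeds", "cond_embeds"]}

# Simulate the denoising loop calling the callback at the end of every step.
for step in range(fake_pipeline.num_timesteps):
    kwargs = callback_dynamic_cfg(fake_pipeline, step, None, kwargs)
```

With a real pipeline, a function like this would be passed as `callback_on_step_end`, with `callback_on_step_end_tensor_inputs=["prompt_embeds"]` so that `prompt_embeds` appears in `callback_kwargs`.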

src/diffusers/callbacks.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
+from typing import Any, Dict, List
+
+from .configuration_utils import ConfigMixin, register_to_config
+from .utils import CONFIG_NAME
+
+
+class PipelineCallback(ConfigMixin):
+    """
+    Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
+    custom callbacks and ensures that all callbacks have a consistent interface.
+
+    Please implement the following:
+        `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
+        include
+        variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
+        `callback_fn`: This method defines the core functionality of your callback.
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
+        super().__init__()
+
+        if (cutoff_step_ratio is None and cutoff_step_index is None) or (
+            cutoff_step_ratio is not None and cutoff_step_index is not None
+        ):
+            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")
+
+        if cutoff_step_ratio is not None and (
+            not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
+        ):
+            raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
+
+    @property
+    def tensor_inputs(self) -> List[str]:
+        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
+
+    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
+
+    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
+
+
+class MultiPipelineCallbacks:
+    """
+    This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
+    provides a unified interface for calling all of them.
+    """
+
+    def __init__(self, callbacks: List[PipelineCallback]):
+        self.callbacks = callbacks
+
+    @property
+    def tensor_inputs(self) -> List[str]:
+        return [input for callback in self.callbacks for input in callback.tensor_inputs]
+
+    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        """
+        Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
+        """
+        for callback in self.callbacks:
+            callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)
+
+        return callback_kwargs
+
+
+class SDCFGCutoffCallback(PipelineCallback):
+    """
+    Callback function for Stable Diffusion Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
+    `cutoff_step_index`), this callback will disable the CFG.
+
+    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+    """
+
+    tensor_inputs = ["prompt_embeds"]
+
+    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+            prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.
+
+            pipeline._guidance_scale = 0.0
+
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+        return callback_kwargs
+
+
+class SDXLCFGCutoffCallback(PipelineCallback):
+    """
+    Callback function for Stable Diffusion XL Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
+    `cutoff_step_index`), this callback will disable the CFG.
+
+    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+    """
+
+    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
+
+    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+            prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.
+
+            add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
+            add_text_embeds = add_text_embeds[-1:]  # "-1" denotes the embeddings for conditional pooled text tokens
+
+            add_time_ids = callback_kwargs[self.tensor_inputs[2]]
+            add_time_ids = add_time_ids[-1:]  # "-1" denotes the embeddings for conditional added time vector
+
+            pipeline._guidance_scale = 0.0
+
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+            callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+        return callback_kwargs
+
+
+class IPAdapterScaleCutoffCallback(PipelineCallback):
+    """
+    Callback function for any pipeline that inherits `IPAdapterMixin`. After a certain number of steps (set by
+    `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.
+
+    Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
+    """
+
+    tensor_inputs = []
+
+    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            pipeline.set_ip_adapter_scale(0.0)
+        return callback_kwargs
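All three callbacks validate and resolve the cutoff step in the same way, which the following sketch isolates. `resolve_cutoff_step` is a hypothetical helper, not part of `diffusers`. It only mirrors the checks in `PipelineCallback.__init__` and the `cutoff_step` expression in each `callback_fn`, and it assumes `step_index` runs from 0 to `num_timesteps - 1`.

```python
def resolve_cutoff_step(num_timesteps, cutoff_step_ratio=1.0, cutoff_step_index=None):
    """Hypothetical helper mirroring the validation and cutoff arithmetic shown in the diff."""
    both_missing = cutoff_step_ratio is None and cutoff_step_index is None
    both_given = cutoff_step_ratio is not None and cutoff_step_index is not None
    if both_missing or both_given:
        raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

    # An explicit index wins; at this point the ratio is guaranteed to be None.
    if cutoff_step_index is not None:
        return cutoff_step_index

    if not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0):
        raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
    return int(num_timesteps * cutoff_step_ratio)


ratio_step = resolve_cutoff_step(25, cutoff_step_ratio=0.4)
index_step = resolve_cutoff_step(25, cutoff_step_ratio=None, cutoff_step_index=10)
default_step = resolve_cutoff_step(25)  # relies on the default ratio of 1.0

# The callbacks compare with `==`, so they act on exactly one step of the loop.
# Under the 0-based assumption above, the default ratio never matches any step.
fires_with_default = default_step in range(25)
```

Two consequences follow from this logic. Because the ratio defaults to `1.0`, using `cutoff_step_index` requires passing `cutoff_step_ratio=None` explicitly, as the documentation example does. And since the ratio must be a `float`, an integer such as `1` is rejected even though it lies in range.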

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 12 additions & 6 deletions
@@ -22,6 +22,7 @@
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
@@ -926,7 +927,9 @@ def __call__(
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         **kwargs,
     ):
@@ -1019,11 +1022,11 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
-            callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function, or an instance of `PipelineCallback` or `MultiPipelineCallbacks`, that is called at the end
+                of each denoising step during inference with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
@@ -1055,6 +1058,9 @@ def __call__(
                 "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
             )
 
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
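The pipeline-side change is small: when the callback is a callback object, its `tensor_inputs` replaces whatever list the caller passed. The sketch below imitates that check together with the chaining done by `MultiPipelineCallbacks`, using plain Python stand-ins so it runs without `diffusers`. `RecordStepCallback`, `HalveLatentsCallback`, `CallbackChain`, and `resolve_tensor_inputs` are hypothetical names, and lists of floats stand in for tensors.

```python
from typing import Any, Dict, List


class RecordStepCallback:
    """Hypothetical callback following the same call signature as `PipelineCallback`."""

    tensor_inputs = ["latents"]

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        callback_kwargs["last_step"] = step_index
        return callback_kwargs


class HalveLatentsCallback:
    """Hypothetical callback that edits one of the tensors it requested."""

    tensor_inputs = ["latents", "prompt_embeds"]

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        callback_kwargs["latents"] = [x / 2 for x in callback_kwargs["latents"]]
        return callback_kwargs


class CallbackChain:
    """Stand-in for `MultiPipelineCallbacks`: same flattening and chaining logic as the diff."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        # Flattened in callback order; repeated names are not removed.
        return [name for callback in self.callbacks for name in callback.tensor_inputs]

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Each callback receives the kwargs returned by the previous one.
        for callback in self.callbacks:
            callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)
        return callback_kwargs


def resolve_tensor_inputs(callback, requested):
    """Mimics the new pipeline check: a callback object overrides the user-supplied list."""
    if isinstance(callback, CallbackChain):
        return callback.tensor_inputs
    return requested


chain = CallbackChain([RecordStepCallback(), HalveLatentsCallback()])
tensor_inputs = resolve_tensor_inputs(chain, ["latents"])
result = chain(None, 3, 500, {"latents": [2.0, 4.0], "prompt_embeds": []})
```

Note that the flattened list keeps duplicates when two callbacks request the same tensor, and that a plain function leaves the caller's `callback_on_step_end_tensor_inputs` untouched.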

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 12 additions & 6 deletions
@@ -21,6 +21,7 @@
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
@@ -917,7 +918,9 @@ def __call__(
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         **kwargs,
     ):
@@ -1004,11 +1007,11 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
-            callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function, or an instance of `PipelineCallback` or `MultiPipelineCallbacks`, that is called at the end
+                of each denoising step during inference with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
@@ -1040,6 +1043,9 @@ def __call__(
                 "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
             )
 
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
