Commit e7120ba

[UNet2DConditionModel] add gradient checkpointing (huggingface#461)
* add grad ckpt to downsample blocks
* make it work
* don't pass gradient_checkpointing to upsample block
* add tests for UNet2DConditionModel
* add test_gradient_checkpointing
* add gradient_checkpointing for up and down blocks
* add functions to enable and disable grad ckpt
* remove the forward argument
* better naming
* make supports_gradient_checkpointing private
1 parent 534512b commit e7120ba

6 files changed: +219, -11 lines changed

src/diffusers/modeling_utils.py

Lines changed: 33 additions & 0 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.

 import os
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

 import torch
@@ -121,10 +122,42 @@ class ModelMixin(torch.nn.Module):
     """
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+    _supports_gradient_checkpointing = False

     def __init__(self):
         super().__init__()

+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+    def enable_gradient_checkpointing(self):
+        """
+        Activates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if not self._supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+    def disable_gradient_checkpointing(self):
+        """
+        Deactivates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if self._supports_gradient_checkpointing:
+            self.apply(partial(self._set_gradient_checkpointing, value=False))
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
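
Note (not part of the diff): a minimal sketch of how the new ModelMixin API is used from user code. The config values mirror the dummy model from the tests added in this commit and are illustrative only.

import torch
from diffusers import UNet2DConditionModel

# small illustrative config, copied from the test init_dict below
model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=32,
)

assert not model.is_gradient_checkpointing  # disabled at init

model.enable_gradient_checkpointing()       # flips gradient_checkpointing=True on the down/up blocks
assert model.is_gradient_checkpointing

# checkpointing only takes effect in training mode (the blocks check self.training)
model.train()
out = model(
    sample=torch.randn(1, 4, 32, 32),
    timestep=torch.tensor([10]),
    encoder_hidden_states=torch.randn(1, 4, 32),
).sample
out.sum().backward()

model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing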

src/diffusers/models/unet_2d_condition.py

Lines changed: 19 additions & 2 deletions
@@ -3,12 +3,21 @@

 import torch
 import torch.nn as nn
+import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput
 from .embeddings import TimestepEmbedding, Timesteps
-from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+from .unet_blocks import (
+    CrossAttnDownBlock2D,
+    CrossAttnUpBlock2D,
+    DownBlock2D,
+    UNetMidBlock2DCrossAttn,
+    UpBlock2D,
+    get_down_block,
+    get_up_block,
+)


 @dataclass
@@ -54,6 +63,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
     """

+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -188,6 +199,10 @@ def set_attention_slice(self, slice_size):
             if hasattr(block, "attentions") and block.attentions is not None:
                 block.set_attention_slice(slice_size)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -234,7 +249,9 @@ def forward(
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                 sample, res_samples = downsample_block(
-                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
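
Note (not part of the diff): enable_gradient_checkpointing reaches these per-block flags through torch.nn.Module.apply, which visits every submodule and calls the bound _set_gradient_checkpointing on each one. A self-contained sketch of that dispatch pattern, with toy module names standing in for the real block classes:

from functools import partial

import torch.nn as nn


class ToyBlock(nn.Module):
    """Stands in for CrossAttnDownBlock2D, UpBlock2D, etc."""

    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(), ToyBlock()])

    def _set_gradient_checkpointing(self, module, value=False):
        # only flip the flag on modules that actually implement the checkpointed path
        if isinstance(module, ToyBlock):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # Module.apply calls the function on self and every submodule recursively
        self.apply(partial(self._set_gradient_checkpointing, value=True))


model = ToyModel()
model.enable_gradient_checkpointing()
print([b.gradient_checkpointing for b in model.blocks])  # [True, True]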

src/diffusers/models/unet_blocks.py

Lines changed: 71 additions & 7 deletions
@@ -527,6 +527,8 @@ def __init__(
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def set_attention_slice(self, slice_size):
         if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
             raise ValueError(
@@ -546,8 +548,22 @@ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn), hidden_states, encoder_hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, context=encoder_hidden_states)

             output_states += (hidden_states,)

         if self.downsamplers is not None:
@@ -609,11 +625,24 @@ def __init__(
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states, temb=None):
         output_states = ()

         for resnet in self.resnets:
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)

             output_states += (hidden_states,)

         if self.downsamplers is not None:
@@ -1072,6 +1101,8 @@ def __init__(
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
+
     def set_attention_slice(self, slice_size):
         if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
             raise ValueError(
@@ -1087,15 +1118,36 @@ def set_attention_slice(self, slice_size):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)

-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        res_hidden_states_tuple,
+        temb=None,
+        encoder_hidden_states=None,
+    ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn), hidden_states, encoder_hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, context=encoder_hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -1150,14 +1202,26 @@ def __init__(
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
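
Note (not part of the diff): torch.utils.checkpoint.checkpoint expects a plain callable invoked with positional tensor arguments, which is why each block wraps the module in a create_custom_forward closure instead of passing it directly with keyword arguments. A self-contained sketch of the pattern, using a toy module and illustrative sizes; the checkpointed path recomputes the block's activations during backward, so outputs and gradients match the regular path:

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyResnet(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, hidden_states, temb):
        # temb stands in for the time embedding the real blocks receive
        return self.net(hidden_states + temb)


def create_custom_forward(module):
    # wrap the module so checkpoint() sees a function of positional tensor inputs
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = ToyResnet()
hidden_states = torch.randn(4, 32, requires_grad=True)
temb = torch.randn(4, 32)

# regular forward: all intermediate activations are kept for backward
out_plain = block(hidden_states, temb)

# checkpointed forward: activations inside `block` are recomputed during backward,
# trading extra compute for lower memory; the forward result is identical
out_ckpt = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, temb)

assert torch.allclose(out_plain, out_ckpt)
out_ckpt.sum().backward()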

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -478,7 +478,7 @@ def forward(
 class LDMBertPreTrainedModel(PreTrainedModel):
     config_class = LDMBertConfig
     base_model_prefix = "model"
-    supports_gradient_checkpointing = True
+    _supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

     def _init_weights(self, module):

tests/test_modeling_common.py

Lines changed: 18 additions & 0 deletions
@@ -246,3 +246,21 @@ def recursive_check(tuple_object, dict_object):
         outputs_tuple = model(**inputs_dict, return_dict=False)

         recursive_check(outputs_tuple, outputs_dict)
+
+    def test_enable_disable_gradient_checkpointing(self):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+        # at init model should have gradient checkpointing disabled
+        model = self.model_class(**init_dict)
+        self.assertFalse(model.is_gradient_checkpointing)
+
+        # check enable works
+        model.enable_gradient_checkpointing()
+        self.assertTrue(model.is_gradient_checkpointing)
+
+        # check disable works
+        model.disable_gradient_checkpointing()
+        self.assertFalse(model.is_gradient_checkpointing)

tests/test_models_unet.py

Lines changed: 77 additions & 1 deletion
@@ -18,7 +18,7 @@

 import torch

-from diffusers import UNet2DModel
+from diffusers import UNet2DConditionModel, UNet2DModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device

 from .test_modeling_common import ModelTesterMixin
@@ -159,6 +159,82 @@ def test_output_pretrained(self):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))


+class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNet2DConditionModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 4
+        sizes = (32, 32)
+
+        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor([10]).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+
+        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+
+    @property
+    def input_shape(self):
+        return (4, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (4, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "block_out_channels": (32, 64),
+            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
+            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
+            "cross_attention_dim": 32,
+            "attention_head_dim": 8,
+            "out_channels": 4,
+            "in_channels": 4,
+            "layers_per_block": 2,
+            "sample_size": 32,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        out = model(**inputs_dict).sample
+        # run the backward pass on the model; for simplicity we skip a loss
+        # function and backprop on out.sum()
+        model.zero_grad()
+        out.sum().backward()
+
+        # save the output and parameter gradients of the non-checkpointed run
+        # for comparison with the checkpointed run below
+        output_not_checkpointed = out.data.clone()
+        grad_not_checkpointed = {}
+        for name, param in model.named_parameters():
+            grad_not_checkpointed[name] = param.grad.data.clone()
+
+        model.enable_gradient_checkpointing()
+        out = model(**inputs_dict).sample
+        # run the backward pass on the model; for simplicity we skip a loss
+        # function and backprop on out.sum()
+        model.zero_grad()
+        out.sum().backward()
+
+        # save the output and parameter gradients of the checkpointed run
+        output_checkpointed = out.data.clone()
+        grad_checkpointed = {}
+        for name, param in model.named_parameters():
+            grad_checkpointed[name] = param.grad.data.clone()
+
+        # compare the outputs and parameter gradients of the two runs
+        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
+        for name in grad_checkpointed:
+            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+
+
 # TODO(Patrick) - Re-add this test after having cleaned up LDM
 # def test_output_pretrained_spatial_transformer(self):
 #     model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
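
Note (not part of the diff): the gradient-equivalence test above checks correctness; the practical benefit of checkpointing is lower peak activation memory during training. A rough way to observe this on a CUDA machine is sketched below (assumes a GPU is available, the helper peak_memory_mb is defined here for illustration, and exact numbers depend on model and batch size):

import torch
from diffusers import UNet2DConditionModel


def peak_memory_mb(model, inputs):
    # measure peak allocated CUDA memory over one forward + backward pass
    torch.cuda.reset_peak_memory_stats()
    out = model(**inputs).sample
    out.sum().backward()
    model.zero_grad()
    return torch.cuda.max_memory_allocated() / 2**20


# same dummy config as in UNet2DConditionModelTests.prepare_init_args_and_inputs_for_common
init_dict = {
    "block_out_channels": (32, 64),
    "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
    "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
    "cross_attention_dim": 32,
    "attention_head_dim": 8,
    "out_channels": 4,
    "in_channels": 4,
    "layers_per_block": 2,
    "sample_size": 32,
}
model = UNet2DConditionModel(**init_dict).to("cuda").train()
inputs = {
    "sample": torch.randn(4, 4, 32, 32, device="cuda"),
    "timestep": torch.tensor([10], device="cuda"),
    "encoder_hidden_states": torch.randn(4, 4, 32, device="cuda"),
}

baseline = peak_memory_mb(model, inputs)
model.enable_gradient_checkpointing()
checkpointed = peak_memory_mb(model, inputs)
print(f"peak memory: {baseline:.1f} MiB without vs {checkpointed:.1f} MiB with checkpointing")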
