
Commit 8ad68c1

Add missing MochiEncoder3D.gradient_checkpointing attribute (huggingface#11146)
* Add missing 'gradient_checkpointing = False' attribute
* Add (limited) tests for the Mochi autoencoder
* Apply style fixes
* Pass 'conv_cache' as a positional argument instead of a keyword argument

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 41afb66 · commit 8ad68c1
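
For context, here is a minimal, hedged sketch of the two behavioural fixes described in the commit message. The class and helper names (TinyEncoder3D, _run_block) are illustrative only, not the actual diffusers code, and the sketch assumes the standard torch.utils.checkpoint API: the checkpointing flag has to exist as an instance attribute before anything can toggle it, and the conv cache is forwarded positionally because not every checkpointing implementation forwards keyword arguments to the wrapped callable.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyEncoder3D(nn.Module):
    # Illustrative stand-in for MochiEncoder3D, not the real implementation.

    def __init__(self):
        super().__init__()
        self.block_in = nn.Linear(8, 8)
        # Fix 1: declare the flag up front so an external toggle such as
        # enable_gradient_checkpointing() has an attribute to flip and
        # forward() can read it without an AttributeError.
        self.gradient_checkpointing = False

    def _run_block(self, hidden_states, cached):
        # Stand-in for a Mochi block that consumes and returns a conv cache entry.
        out = self.block_in(hidden_states)
        return out, out.detach()

    def forward(self, hidden_states, conv_cache=None):
        conv_cache = conv_cache or {}
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Fix 2: forward the cached value positionally; keyword arguments
            # for the wrapped callable are not accepted by every checkpointing
            # implementation (the reentrant torch.utils.checkpoint path in
            # particular rejects them).
            hidden_states, new_cache = checkpoint(
                self._run_block,
                hidden_states,
                conv_cache.get("block_in"),
                use_reentrant=False,
            )
        else:
            hidden_states, new_cache = self._run_block(
                hidden_states, conv_cache.get("block_in")
            )
        return hidden_states, {"block_in": new_cache}


model = TinyEncoder3D()
model.gradient_checkpointing = True  # roughly what enabling checkpointing does
x = torch.randn(2, 8, requires_grad=True)
out, cache = model(x)
out.sum().backward()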

2 files changed, 120 insertions(+), 7 deletions(-)

src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 9 additions & 7 deletions
@@ -210,7 +210,7 @@ def forward(
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
                     resnet,
                     hidden_states,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -306,7 +306,7 @@ def forward(

             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
-                    resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+                    resnet, hidden_states, conv_cache.get(conv_cache_key)
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -382,7 +382,7 @@ def forward(
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
                     resnet,
                     hidden_states,
-                    conv_cache=conv_cache.get(conv_cache_key),
+                    conv_cache.get(conv_cache_key),
                 )
             else:
                 hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -497,6 +497,8 @@ def __init__(
         self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
         self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)

+        self.gradient_checkpointing = False
+
     def forward(
         self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
     ) -> torch.Tensor:
@@ -513,13 +515,13 @@ def forward(

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
-                self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
+                self.block_in, hidden_states, conv_cache.get("block_in")
             )

             for i, down_block in enumerate(self.down_blocks):
                 conv_cache_key = f"down_block_{i}"
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
-                    down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+                    down_block, hidden_states, conv_cache.get(conv_cache_key)
                 )
         else:
             hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -623,13 +625,13 @@ def forward(
         # 1. Mid
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
-                self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
+                self.block_in, hidden_states, conv_cache.get("block_in")
            )

             for i, up_block in enumerate(self.up_blocks):
                 conv_cache_key = f"up_block_{i}"
                 hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
-                    up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+                    up_block, hidden_states, conv_cache.get(conv_cache_key)
                 )
         else:
             hidden_states, new_conv_cache["block_in"] = self.block_in(
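
With the attribute declared in MochiEncoder3D.__init__, the model-level checkpointing toggle can mark the encoder like the other Mochi blocks. Below is a hedged sketch of how one might verify that. The reduced configuration is copied from the new test module added below, and the assumption that enable_gradient_checkpointing() sets the flag on submodules that declare it reflects how the test's expected_set is meant to be satisfied.

from diffusers import AutoencoderKLMochi

# Reduced configuration, copied from the test module added below.
config = {
    "in_channels": 15,
    "out_channels": 3,
    "latent_channels": 4,
    "encoder_block_out_channels": (32, 32, 32, 32),
    "decoder_block_out_channels": (32, 32, 32, 32),
    "layers_per_block": (1, 1, 1, 1, 1),
    "act_fn": "silu",
    "scaling_factor": 1,
}

vae = AutoencoderKLMochi(**config)
vae.enable_gradient_checkpointing()

# Collect the module classes that ended up with the flag set. Before this
# commit, MochiEncoder3D never declared the attribute, so it could not be
# marked here.
applied = {
    module.__class__.__name__
    for module in vae.modules()
    if getattr(module, "gradient_checkpointing", False)
}
print(sorted(applied))  # expected to include "MochiEncoder3D" with the other Mochi blocks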

tests/models/autoencoders/test_models_autoencoder_mochi.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLMochi
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    floats_tensor,
+    torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+    model_class = AutoencoderKLMochi
+    main_input_name = "sample"
+    base_precision = 1e-2
+
+    def get_autoencoder_kl_mochi_config(self):
+        return {
+            "in_channels": 15,
+            "out_channels": 3,
+            "latent_channels": 4,
+            "encoder_block_out_channels": (32, 32, 32, 32),
+            "decoder_block_out_channels": (32, 32, 32, 32),
+            "layers_per_block": (1, 1, 1, 1, 1),
+            "act_fn": "silu",
+            "scaling_factor": 1,
+        }
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        num_frames = 7
+        num_channels = 3
+        sizes = (16, 16)
+
+        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+        return {"sample": image}
+
+    @property
+    def input_shape(self):
+        return (3, 7, 16, 16)
+
+    @property
+    def output_shape(self):
+        return (3, 7, 16, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = self.get_autoencoder_kl_mochi_config()
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {
+            "MochiDecoder3D",
+            "MochiDownBlock3D",
+            "MochiEncoder3D",
+            "MochiMidBlock3D",
+            "MochiUpBlock3D",
+        }
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @unittest.skip("Unsupported test.")
+    def test_forward_with_norm_groups(self):
+        """
+        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
+        TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
+        """
+        pass
+
+    @unittest.skip("Unsupported test.")
+    def test_model_parallelism(self):
+        """
+        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
+        RuntimeError: values expected sparse tensor layout but got Strided
+        """
+        pass
+
+    @unittest.skip("Unsupported test.")
+    def test_outputs_equivalence(self):
+        """
+        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
+        RuntimeError: values expected sparse tensor layout but got Strided
+        """
+        pass
+
+    @unittest.skip("Unsupported test.")
+    def test_sharded_checkpoints_device_map(self):
+        """
+        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_sharded_checkpoints_device_map -
+        RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:5!
+        """

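To close the loop, here is a hedged smoke run that mirrors the fixtures above (the same reduced configuration and the same input shape as dummy_input). The forward call is assumed to follow the usual diffusers autoencoder signature, so the output is unwrapped defensively; this is an illustration, not part of the committed test suite.

import torch

from diffusers import AutoencoderKLMochi

# Same reduced configuration as get_autoencoder_kl_mochi_config above.
config = {
    "in_channels": 15,
    "out_channels": 3,
    "latent_channels": 4,
    "encoder_block_out_channels": (32, 32, 32, 32),
    "decoder_block_out_channels": (32, 32, 32, 32),
    "layers_per_block": (1, 1, 1, 1, 1),
    "act_fn": "silu",
    "scaling_factor": 1,
}

vae = AutoencoderKLMochi(**config)
vae.enable_gradient_checkpointing()
vae.train()

# Same shape as dummy_input above: batch of 2, 3 channels, 7 frames, 16x16.
sample = torch.randn(2, 3, 7, 16, 16, requires_grad=True)

output = vae(sample)
# Unwrap defensively: diffusers models typically return an output dataclass
# with a .sample field, or a plain tuple when return_dict=False is used.
decoded = output.sample if hasattr(output, "sample") else output[0]
decoded.float().sum().backward()  # exercises the checkpointed branch in MochiEncoder3D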