Skip to content

Commit de14261

Browse files
prathikrPrathik Raoroot
authored
Make UNet2DConditionOutput pickle-able (huggingface#3857)
* add default to unet output to prevent it from being a required arg * add unit test * make style * adjust unit test * mark as fast test * adjust assert statement in test --------- Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
1 parent 41ea88f commit de14261

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class UNet2DConditionOutput(BaseOutput):
5757
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
5858
"""
5959

60-
sample: torch.FloatTensor
60+
sample: torch.FloatTensor = None
6161

6262

6363
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

tests/models/test_models_unet_2d_condition.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import copy
1617
import gc
1718
import os
1819
import tempfile
@@ -782,6 +783,22 @@ def test_custom_diffusion_xformers_on_off(self):
782783
assert (sample - on_sample).abs().max() < 1e-4
783784
assert (sample - off_sample).abs().max() < 1e-4
784785

786+
def test_pickle(self):
787+
# enable deterministic behavior for gradient checkpointing
788+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
789+
790+
init_dict["attention_head_dim"] = (8, 16)
791+
792+
model = self.model_class(**init_dict)
793+
model.to(torch_device)
794+
795+
with torch.no_grad():
796+
sample = model(**inputs_dict).sample
797+
798+
sample_copy = copy.copy(sample)
799+
800+
assert (sample - sample_copy).abs().max() < 1e-4
801+
785802

786803
@slow
787804
class UNet2DConditionModelIntegrationTests(unittest.TestCase):

0 commit comments

Comments
 (0)