Skip to content

[test_models_transformer_hunyuan_video] help us test torch.compile() for impactful models #11431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 30, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
update
  • Loading branch information
sayakpaul committed Apr 30, 2025
commit 1b4ba9af503f7da4eda9313275991197ffb06e2b
22 changes: 12 additions & 10 deletions tests/models/transformers/test_models_transformer_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@
import torch

from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, require_torch_2, is_torch_compile, slow
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_torch_compile,
require_torch_2,
require_torch_gpu,
slow,
torch_device,
)

from ..test_modeling_common import ModelTesterMixin

Expand Down Expand Up @@ -86,18 +92,14 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
Expand Down Expand Up @@ -191,7 +193,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)


class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
Expand Down Expand Up @@ -257,7 +259,7 @@ def test_output(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@require_torch_gpu
@require_torch_2
@is_torch_compile
Expand Down Expand Up @@ -336,7 +338,7 @@ def prepare_init_args_and_inputs_for_common(self):

def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))

def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
Expand Down
Loading