torch.compile fullgraph compatibility for Hunyuan Video #11457

Merged: 1 commit, Apr 30, 2025

Conversation

@a-r-r-o-w (Member) commented on Apr 29, 2025

Fixes #11431 (review)

Will run the full model for testing in some time.

Testing code:
import torch
from diffusers import HunyuanVideoTransformer3DModel

@torch.no_grad()
def main():
    device = "cuda"
    dtype = torch.bfloat16
    batch_size = 1
    num_channels = 4
    num_frames = 2
    height = 4
    width = 4
    sequence_length = 8
    
    # Tiny model configuration so compilation and the forward pass run quickly
    transformer = HunyuanVideoTransformer3DModel(
        in_channels=4,
        out_channels=4,
        num_attention_heads=2,
        attention_head_dim=10,
        num_layers=2,
        num_single_layers=2,
        num_refiner_layers=1,
        patch_size=1,
        patch_size_t=1,
        guidance_embeds=True,
        text_embed_dim=16,
        pooled_projection_dim=8,
        rope_axes_dim=(2, 4, 4),
    ).to(device=device, dtype=dtype)

    # Random dummy inputs matching the tiny configuration above
    hidden_states = torch.randn(batch_size, num_channels, num_frames, height, width, dtype=dtype, device=device)
    timestep = torch.randint(0, 1000, (batch_size,), dtype=torch.long, device=device)
    encoder_hidden_states = torch.randn(batch_size, sequence_length, 16, dtype=dtype, device=device)
    # Mark the last two positions of each sequence as padding to exercise the attention-mask path
    encoder_attention_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
    encoder_attention_mask[:, :sequence_length - 2] = True
    pooled_projections = torch.randn(batch_size, 8, dtype=dtype, device=device)
    guidance = torch.randint(0, 1000, (batch_size,), dtype=torch.long, device=device)

    # fullgraph=True makes torch.compile raise on any graph break instead of silently falling back to eager
    transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
    output = transformer(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        pooled_projections=pooled_projections,
        guidance=guidance,
        return_dict=False,
    )[0]
    print(output.shape)


if __name__ == "__main__":
    main()

@a-r-r-o-w requested review from sayakpaul and Copilot on April 29, 2025 at 22:24
@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR improves fullgraph compatibility for Hunyuan Video by replacing a loop‑based attention mask construction with a vectorized approach.

  • Replaces manual per‐batch loop with a vectorized masked_fill operation.
  • Updates the attention mask initialization from zeros to ones and adds appropriate unsqueezing for broadcasting.
Comments suppressed due to low confidence (1)

src/diffusers/models/transformers/transformer_hunyuan_video.py:1071

  • [nitpick] The refactored attention mask construction is more efficient; consider adding an inline comment that explains the logic behind initializing with ones and using masked_fill for clarity.
attention_mask = torch.ones(batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool)
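
For reference, a minimal standalone sketch of the pattern described above; the shapes and the valid_lengths tensor are hypothetical and this is not the actual diffusers code. It contrasts the per-batch Python loop with a single vectorized masked_fill, which avoids the data-dependent control flow that can break torch.compile(fullgraph=True).

import torch

# Hypothetical shapes and per-sample valid lengths; illustrative only.
batch_size, sequence_length = 2, 8
valid_lengths = torch.tensor([6, 8])  # number of non-padding tokens per batch entry

# Loop-based construction: Python-level, data-dependent control flow like this
# tends to cause graph breaks under torch.compile(fullgraph=True).
mask_loop = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
for i in range(batch_size):
    mask_loop[i, : valid_lengths[i]] = True

# Vectorized construction: start from ones and masked_fill the padding positions,
# keeping everything as traceable tensor ops.
positions = torch.arange(sequence_length).unsqueeze(0)  # (1, sequence_length)
mask_vectorized = torch.ones(batch_size, sequence_length, dtype=torch.bool)
mask_vectorized = mask_vectorized.masked_fill(positions >= valid_lengths.unsqueeze(1), False)

assert torch.equal(mask_loop, mask_vectorized)

Both constructions produce the same boolean mask; the vectorized form simply stays inside the traced graph.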

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment

Looks solid, thanks for working on this!

To confirm, I first checked out the PR branch of #11431, merged this PR's branch into it, and then ran RUN_SLOW=1 RUN_COMPILE=1 pytest tests/models/transformers/test_models_transformer_hunyuan_video.py -k "test_torch_compile_recompilation_and_graph_break". Everything was green.

@sayakpaul added the performance and torch.compile labels on Apr 30, 2025
@a-r-r-o-w merged commit c865115 into main on Apr 30, 2025
15 of 16 checks passed
@a-r-r-o-w deleted the improve-hunyuan-compile-support branch on April 30, 2025 at 05:51
Labels: performance, torch.compile
3 participants