
[tests] add tests for framepack transformer model. #11520

Merged
sayakpaul merged 9 commits into main from framepack-transformer-tests on May 11, 2025

Conversation

sayakpaul
Member

What does this PR do?

@a-r-r-o-w the following two model splitting tests are failing:

FAILED tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py::HunyuanVideoTransformer3DTests::test_model_parallelism - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argumen...
FAILED tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py::HunyuanVideoTransformer3DTests::test_sharded_checkpoints_device_map - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument

Could you take a look when you have time? Just as an FYI, there are similar failures in the HunyuanVideo transformer model, too. Also, cc: @SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Member

For these tests to pass, clean_x_embedder and x_embedder are expected to end up on the same device. This is because accelerate allocates devices to the different layers based on their initialization order. Moving the clean_x_embedder initialization right below x_embedder, and the image_projection initialization right below context_embedder, fixes the error for the tests.

diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index 0331d9934..012a6e532 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -152,9 +152,14 @@ class HunyuanVideoFramepackTransformer3DModel(
 
         # 1. Latent and condition embedders
         self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+        self.clean_x_embedder = None
+        if has_clean_x_embedder:
+            self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
         self.context_embedder = HunyuanVideoTokenRefiner(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
+        # Framepack specific modules
+        self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
         self.time_text_embed = HunyuanVideoConditionEmbedding(
             inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
         )
@@ -186,13 +191,6 @@ class HunyuanVideoFramepackTransformer3DModel(
         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
 
-        # Framepack specific modules
-        self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
-
-        self.clean_x_embedder = None
-        if has_clean_x_embedder:
-            self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
-
         self.gradient_checkpointing = False
 
     def forward(
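
To see the effect of registration order concretely, one can inspect the device map that accelerate computes. A rough sketch with a toy module (the module, layer sizes, and memory limits here are made up purely for illustration):

import torch.nn as nn
from accelerate import infer_auto_device_map

# Toy stand-in for the transformer: two embedders followed by a stack of blocks.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.x_embedder = nn.Linear(64, 64)
        self.clean_x_embedder = nn.Linear(64, 64)  # registered right after x_embedder
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(8)])

# accelerate walks submodules in registration order when filling each device,
# so layers registered next to each other tend to land on the same device.
device_map = infer_auto_device_map(ToyModel(), max_memory={0: "64KB", 1: "1MB"})
print(device_map)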

But this is not a "correct" fix in the general case. For it to work properly in general, we would need device-handling code in the concatenation statements. Something like:

hidden_states = torch.cat([latents_clean.to(hidden_states), hidden_states], dim=1)

IMO it makes the code look unnecessarily complicated, since these tensors are already expected to be on the correct device/dtype in the single-GPU case. If we'd like to make these changes anyway, LMK and I'll open a PR.
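
As an aside, Tensor.to(other) takes both the dtype and the device from the other tensor, which is why a single .to(hidden_states) call is enough. A minimal standalone check (CPU-only here, just to show the semantics):

import torch

a = torch.randn(2, 3, dtype=torch.float32)   # imagine this living on cuda:0 in the sharded case
b = torch.randn(2, 3, dtype=torch.float16)   # imagine this living on cuda:1

# Tensor.to(other) converts to other's dtype and device in one call.
a_like_b = a.to(b)
assert a_like_b.dtype == b.dtype and a_like_b.device == b.device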

@sayakpaul
Member Author

But this is not a "correct" fix in the general case. For it to work properly in general, we would need device-handling code in the concatenation statements. Something like:

Exactly why I didn't make these changes; I strongly echo your opinion on this.

So, given that, I think it's still preferable to go with the other option you mentioned, i.e., reordering the layer initialization.

@a-r-r-o-w
Member

a-r-r-o-w commented May 10, 2025

#11535 should hopefully fix the error you're seeing for Framepack. For Hunyuan Video, the following patch is required:

--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -112,11 +112,12 @@ class HunyuanVideoAttnProcessor2_0:
             if attn.norm_added_k is not None:
                 encoder_key = attn.norm_added_k(encoder_key)
 
-            query = torch.cat([query, encoder_query], dim=2)
-            key = torch.cat([key, encoder_key], dim=2)
-            value = torch.cat([value, encoder_value], dim=2)
+            query = torch.cat([query, encoder_query.to(query)], dim=2)
+            key = torch.cat([key, encoder_key.to(key)], dim=2)
+            value = torch.cat([value, encoder_value.to(value)], dim=2)
 
         # 5. Attention
+        key, value, attention_mask = (x.to(query.device) for x in (key, value, attention_mask))
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
@@ -865,8 +866,11 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
     _no_split_modules = [
+        "HunyuanVideoConditionEmbedding",
         "HunyuanVideoTransformerBlock",
         "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoTokenReplaceTransformerBlock",
+        "HunyuanVideoTokenReplaceSingleTransformerBlock",
         "HunyuanVideoPatchEmbed",
         "HunyuanVideoTokenRefiner",
     ]

I think the _no_split_modules changes are okay, but the changes to the attention processor complicate otherwise easily readable code (same reasoning as for not using the "correct" fix mentioned above). I don't think the model parallelism implementation was really meant to handle complex cases like this, similar to how group offloading does not really work as expected with MoE implementations. Probably better to skip the tests and note why they would fail, but up to you.

Edit: The model parallelism implementation could handle this if the attention processors were nn.Modules, but since they are just plain wrapper classes, the necessary device-moving hooks never get registered on them.
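
For illustration, skipping the two tests could look roughly like this (a sketch: the class name is made up, the real diffusers test classes also inherit ModelTesterMixin, and the skip reasons just paraphrase the explanation above):

import unittest

class HunyuanVideoTransformerModelSplitTests(unittest.TestCase):
    # The actual test class also inherits the diffusers ModelTesterMixin; omitted
    # here to keep the sketch self-contained.

    @unittest.skip(
        "Attention processors are plain classes rather than nn.Modules, so accelerate "
        "never registers its device-placement hooks on them and multi-GPU splits fail."
    )
    def test_model_parallelism(self):
        pass

    @unittest.skip("Same reason as test_model_parallelism.")
    def test_sharded_checkpoints_device_map(self):
        pass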

@sayakpaul
Member Author

I don't think the model parallelism implementation was really meant to handle complex cases like this, similar to how group offloading does not really work as expected with MoE implementations. Probably better to skip the tests and note why they would fail, but up to you.

Thanks, would it be possible to skip them accordingly and batch that in a separate PR?

I have confirmed that your fixes in #11535 solve my initial issue. So, please have a look at this PR and LMK.

@a-r-r-o-w
Member

Sounds good, will skip in a separate PR

@sayakpaul merged commit 01abfc8 into main on May 11, 2025
16 checks passed
@sayakpaul deleted the framepack-transformer-tests branch on May 11, 2025 at 04:20