-
Notifications
You must be signed in to change notification settings - Fork 6k
[tests] add tests for framepack transformer model. #11520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
It is expected that clean_x_embedder and x_embedder are put on the same device for this to pass. It is because accelerate performs the device allocation for different layers based on their initialization order. Moving the diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index 0331d9934..012a6e532 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -152,9 +152,14 @@ class HunyuanVideoFramepackTransformer3DModel(
# 1. Latent and condition embedders
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+ self.clean_x_embedder = None
+ if has_clean_x_embedder:
+ self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
+ # Framepack specific modules
+ self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
self.time_text_embed = HunyuanVideoConditionEmbedding(
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
)
@@ -186,13 +191,6 @@ class HunyuanVideoFramepackTransformer3DModel(
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
- # Framepack specific modules
- self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
-
- self.clean_x_embedder = None
- if has_clean_x_embedder:
- self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
-
self.gradient_checkpointing = False
def forward( But, this is not a "correct" fix in the general case. We need to put in device handling code in the concatenate statements for it to work as expected in the correct way. Something like: hidden_states = torch.cat([latents_clean.to(hidden_states), hidden_states], dim=1) It makes the code look unnecessarily complicated IMO since it is expected that these would already be on the correct device/dtype in the single GPU case. If we'd like to make these changes anyway, LMK and I'll open a PR. |
Exactly why I didn't make these changes because I strongly echo you opinions on it. So, given that, I think it's still preferable to go with the other option you mentioned i.e., corresponding to the initialization order. |
#11535 should hopefully fix the error you're seeing for Framepack. For Hunyuan Video, the following patch is required: --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -112,11 +112,12 @@ class HunyuanVideoAttnProcessor2_0:
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
- query = torch.cat([query, encoder_query], dim=2)
- key = torch.cat([key, encoder_key], dim=2)
- value = torch.cat([value, encoder_value], dim=2)
+ query = torch.cat([query, encoder_query.to(query)], dim=2)
+ key = torch.cat([key, encoder_key.to(key)], dim=2)
+ value = torch.cat([value, encoder_value.to(value)], dim=2)
# 5. Attention
+ key, value, attention_mask = (x.to(query.device) for x in (key, value, attention_mask))
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -865,8 +866,11 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
+ "HunyuanVideoConditionEmbedding",
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
+ "HunyuanVideoTokenReplaceTransformerBlock",
+ "HunyuanVideoTokenReplaceSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
] I think no split modules changes are okay, but the changes to attention processor seem to complicate easily readable code (same reasoning as not using the "correct" fix mentioned above). I don't think the model parallelism implementation was really meant to handle complex cases like this, similar to how group offloading does not really work as expected with MoE implementation. Probably better to skip the test making note of why it would fail, but up to you. Edit: The model parallel implementation can handle this if attention processors were nn.Module but since they are just a wrapper class, it does not have the necessary device-modifying hooks registered |
Thanks, would it be possible to skip them accordingly and batch in a separate PR? I have confirmed that your fixes in #11535 solve my initial issue. So, please have a look at this PR and LMK. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, will skip in a separate PR
What does this PR do?
@a-r-r-o-w the following two model splitting tests are failing:
Could you take a look when you have time? There are similar failures in HunuyanVideo transformer model, too, just as an FYI. Also, cc: @SunMarc