
Conversation

@kakukakujirori (Contributor)

What does this PR do?

A small bug exists in LTXImageToVideoPipeline.prepare_latents() when latents is already provided.
latents is assumed to be a five-dimensional tensor (batch, channel, num_frames, height, width), as we can see from the line

num_frames = (
    (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
)

However, when latents is passed as an argument, the code skips applying self._pack_latents().

Also, the shape of conditioning_mask is wrong.

This PR addresses these two issues.
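For reference, here is a rough sketch of the packing step in question. This is not the exact diffusers implementation; it assumes spatial and temporal patch sizes of 1 (the concrete shape values are taken from the repro snippet below), in which case packing reduces to a permute plus a flatten:

import torch

def pack_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    # (batch, channels, frames, height, width) -> (batch, frames * height * width, channels)
    # rough approximation of LTXImageToVideoPipeline._pack_latents for patch sizes of 1
    return latents.permute(0, 2, 3, 4, 1).flatten(1, 3)

unpacked = torch.randn(1, 128, 7, 11, 20)  # the ndim=5 layout prepare_latents() builds internally
packed = pack_latents_sketch(unpacked)
print(packed.shape)  # torch.Size([1, 1540, 128]): the ndim=3 layout the transformer consumes

The shape logic quoted above expects the unpacked (ndim=5) layout, while the transformer consumes the packed (ndim=3) one, which is why skipping _pack_latents() for user-supplied latents leads to the failure below.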

"""Code snippet to see the error
"""

import torch
from diffusers import LTXImageToVideoPipeline

device = "cuda:0"

# instantiate a pipeline
pipe = LTXImageToVideoPipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.1-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload(device=device)

# create a dummy latents tensor
num_frames = 49
height = 352
width = 640

latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1
latent_height = height // pipe.vae_spatial_compression_ratio
latent_width = width // pipe.vae_spatial_compression_ratio

latents = torch.randn((1, 128, latent_num_frames, latent_height, latent_width), device=device)

# run
pipe(
    height=height,
    width=width,
    num_frames=num_frames,
    prompt="test_test",
    latents=latents,
)
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 28
     25 latents = torch.randn((1, 128, latent_num_frames, latent_height, latent_width), device=device)
     27 # run
---> 28 pipe(
     29     height=height,
     30     width=width,
     31     num_frames=num_frames,
     32     prompt="test_test",
     33     latents=latents,
     34 )

File ~/miniconda3/envs/py311/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/py311/lib/python3.11/site-packages/diffusers/pipelines/ltx/pipeline_ltx_image2video.py:779, in LTXImageToVideoPipeline.__call__(self, image, prompt, negative_prompt, height, width, num_frames, frame_rate, num_inference_steps, timesteps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, decode_timestep, decode_noise_scale, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    777 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    778 timestep = t.expand(latent_model_input.shape[0])
--> 779 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
    781 noise_pred = self.transformer(
    782     hidden_states=latent_model_input,
    783     encoder_hidden_states=prompt_embeds,
   (...)
    791     return_dict=False,
    792 )[0]
    793 noise_pred = noise_pred.float()

RuntimeError: The size of tensor a (2) must match the size of tensor b (1540) at non-singleton dimension 1

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@a-r-r-o-w (Contributor)

I don't think there's a mistake with handling latents here. When the user calls prepare_latents with their own latents, they are assumed to already be "prepared" (in this case, packed into an ndim=3 tensor), and the only operation we wish to perform on the latents is device and dtype casting.

Regarding the mask shape, I believe that might be an actual mistake. Could you try running inference with only the mask_shape-related change and passing an ndim=3 latent?

@kakukakujirori (Contributor, Author)

When the user calls prepare_latents with their own latents, they are assumed to already be "prepared" (in this case, packed into an ndim=3 tensor)

This case also fails. Since the packed latents have shape (batch, num_patches, num_channels), the line

num_frames = (
    (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
)

sets num_frames to num_channels, which is not what is intended.
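To make that concrete with the latent shape used in the snippets in this thread (packing approximated with patch sizes of 1, as in the sketch above):

import torch

unpacked = torch.randn(1, 128, 7, 11, 20)               # (batch, channels, frames, height, width)
packed = unpacked.permute(0, 2, 3, 4, 1).flatten(1, 3)   # (batch, frames * height * width, channels)

print(unpacked.size(2))  # 7   -> the latent frame count prepare_latents() wants
print(packed.size(2))    # 128 -> the channel count it actually reads from a packed latent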

The following is the result when:

  • the input has been packed before being fed to the pipeline, and
  • the latents packing is removed from prepare_latents()
"""Code snippet to see the error
"""

import torch
from diffusers import LTXImageToVideoPipeline

device = "cuda:0"

# instantiate a pipeline
pipe = LTXImageToVideoPipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.1-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload(device=device)

# create a dummy latents tensor
num_frames = 49
height = 352
width = 640

latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1
latent_height = height // pipe.vae_spatial_compression_ratio
latent_width = width // pipe.vae_spatial_compression_ratio

latents = torch.randn((1, 128, latent_num_frames, latent_height, latent_width), device=device)

latents = pipe._pack_latents(latents, pipe.transformer_spatial_patch_size, pipe.transformer_temporal_patch_size)

# run
pipe(
    height=height,
    width=width,
    num_frames=num_frames,
    prompt="test_test",
    latents=latents,
)
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 30
     27 latents = pipe._pack_latents(latents, pipe.transformer_spatial_patch_size, pipe.transformer_temporal_patch_size)
     29 # run
---> 30 pipe(
     31     height=height,
     32     width=width,
     33     num_frames=num_frames,
     34     prompt="test_test",
     35     latents=latents,
     36 )

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_image2video.py:784, in LTXImageToVideoPipeline.__call__(self, image, prompt, negative_prompt, height, width, num_frames, frame_rate, num_inference_steps, timesteps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, decode_timestep, decode_noise_scale, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    781 timestep = t.expand(latent_model_input.shape[0])
    782 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
--> 784 noise_pred = self.transformer(
    785     hidden_states=latent_model_input,
    786     encoder_hidden_states=prompt_embeds,
    787     timestep=timestep,
    788     encoder_attention_mask=prompt_attention_mask,
    789     num_frames=latent_num_frames,
    790     height=latent_height,
    791     width=latent_width,
    792     rope_interpolation_scale=rope_interpolation_scale,
    793     attention_kwargs=attention_kwargs,
    794     return_dict=False,
    795 )[0]
    796 noise_pred = noise_pred.float()
    798 if self.do_classifier_free_guidance:

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/diffusers/models/transformers/transformer_ltx.py:440, in LTXVideoTransformer3DModel.forward(self, hidden_states, encoder_hidden_states, timestep, encoder_attention_mask, num_frames, height, width, rope_interpolation_scale, attention_kwargs, return_dict)
    430         hidden_states = torch.utils.checkpoint.checkpoint(
    431             create_custom_forward(block),
    432             hidden_states,
   (...)
    437             **ckpt_kwargs,
    438         )
    439     else:
--> 440         hidden_states = block(
    441             hidden_states=hidden_states,
    442             encoder_hidden_states=encoder_hidden_states,
    443             temb=temb,
    444             image_rotary_emb=image_rotary_emb,
    445             encoder_attention_mask=encoder_attention_mask,
    446         )
    448 scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
    449 shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/diffusers/models/transformers/transformer_ltx.py:245, in LTXVideoTransformerBlock.forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb, encoder_attention_mask)
    243 ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
    244 shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
--> 245 norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
    247 attn_hidden_states = self.attn1(
    248     hidden_states=norm_hidden_states,
    249     encoder_hidden_states=None,
    250     image_rotary_emb=image_rotary_emb,
    251 )
    252 hidden_states = hidden_states + attn_hidden_states * gate_msa

RuntimeError: The size of tensor a (1540) must match the size of tensor b (28160) at non-singleton dimension 1

@a-r-r-o-w (Contributor)

Ohh okay, I see! nice catch 🔥

cc @yiyixuxu What do we want to do here? Accept fully prepared latents from the user (ndim=3) and do a fix for that, or accept an ndim=5 tensor and prepare it?

@yiyixuxu (Collaborator) commented Mar 4, 2025

cc @yiyixuxu What do we want to do here? Accept fully prepared latents from the user (ndim=3) and do a fix for that, or accept an ndim=5 tensor and prepare it?

I think it should be fully prepared latents (the output of prepare_latents), and we should do a fix for that.

@kakukakujirori (Contributor, Author)

Fixed. We can check the validity using the same code snippet above.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


 if latents is not None:
-    conditioning_mask = latents.new_zeros(shape)
+    conditioning_mask = latents.new_zeros(mask_shape)
Collaborator

why do we need this change?

Contributor Author

The normal route of prepare_latents() outputs conditioning_mask with that shape, so it is natural to align with it (here).
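For context, here is a rough sketch of how the normal route (latents is None) builds the conditioning mask, which the latents-provided branch should mirror. Variable names and the patch handling are simplified and are not an exact copy of the pipeline code:

import torch

batch_size, latent_num_frames, latent_height, latent_width = 1, 7, 11, 20

# mask_shape carries a single channel, unlike the full latent shape (batch, 128, frames, height, width)
mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
conditioning_mask = torch.zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0  # only the first latent frame is conditioned on the input image

# packing (patch sizes of 1) collapses (frames, height, width) into one token axis, matching the packed latents
conditioning_mask = conditioning_mask.permute(0, 2, 3, 4, 1).flatten(1, 3).squeeze(-1)
print(conditioning_mask.shape)  # torch.Size([1, 1540]) -> one mask value per latent token

With this per-token shape, timestep.unsqueeze(-1) * (1 - conditioning_mask) in __call__ broadcasts against the packed latents, whereas a mask built from the full latent shape does not (which is what the first RuntimeError above shows).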

@github-actions (Contributor)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot added the "stale (Issues that haven't received updates)" label on Mar 29, 2025
@yiyixuxu removed the "stale (Issues that haven't received updates)" label on Mar 31, 2025
@yiyixuxu (Collaborator) left a comment

thanks for the fix!
sorry I let this PR go stale!

@yiyixuxu merged commit e8fc8b1 into huggingface:main on Mar 31, 2025
11 of 12 checks passed