Conversation

@hlky (Contributor) commented Dec 9, 2024

What does this PR do?

Refactors get_3d_rotary_pos_embed and get_3d_rotary_pos_embed_allegro to use torch instead of numpy, and adds a device argument so that tensors can be created directly on e.g. cuda.

Usage of get_3d_rotary_pos_embed and get_3d_rotary_pos_embed_allegro is updated to pass device where applicable. We don't specify a device during initialization, so device is not passed when the functions are called from init; there the device from the weights would just be cpu.

The torch and numpy versions match numerically.
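As a quick sanity check of that claim (my own sketch, not part of the PR), the linspace form used throughout reduces to arange when the start offset is zero, and pinning float32 on both sides keeps the two backends comparable:

```python
import numpy as np
import torch

# With start = 0, linspace(0, n * (n - 1) / n, n) enumerates 0, 1, ..., n - 1,
# so it matches both np.arange and torch.arange. np.linspace defaults to
# float64, hence the explicit float32 dtypes on both sides.
n = 49
grid_np = np.linspace(0, n * (n - 1) / n, n, dtype=np.float32)
grid_pt = torch.linspace(0, n * (n - 1) / n, n, dtype=torch.float32)
torch.testing.assert_close(grid_pt, torch.from_numpy(grid_np))
torch.testing.assert_close(grid_pt, torch.arange(n, dtype=torch.float32))
```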

Reproduction: get_3d_rotary_pos_embed
from diffusers.models.embeddings import get_1d_rotary_pos_embed, get_3d_rotary_pos_embed
import torch
from typing import Optional, Tuple, Union

embed_dim = 1920
sample_height = 60
sample_width = 90
sample_frames = 49
patch_size = 2
patch_size_t = 2
temporal_compression_ratio = 4
spatial_interpolation_scale = 1.875
temporal_interpolation_scale = 1.0
vae_scale_factor_spatial = 8
attention_head_dim = 64
num_frames = 49
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames

height = sample_height * vae_scale_factor_spatial
width = sample_width * vae_scale_factor_spatial

grid_height = height // (vae_scale_factor_spatial * patch_size)
grid_width = width // (vae_scale_factor_spatial * patch_size)
p = patch_size
p_t = patch_size_t
base_size_width = sample_width // p
base_size_height = sample_height // p

def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
  tw = tgt_width
  th = tgt_height
  h, w = src
  r = h / w
  if r > (th / tw):
      resize_height = th
      resize_width = int(round(th / h * w))
  else:
      resize_width = tw
      resize_height = int(round(tw / w * h))

  crop_top = int(round((th - resize_height) / 2.0))
  crop_left = int(round((tw - resize_width) / 2.0))

  return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

def get_3d_rotary_pos_embed_torch(
  embed_dim,
  crops_coords,
  grid_size,
  temporal_size,
  theta: int = 10000,
  use_real: bool = True,
  grid_type: str = "linspace",
  max_size: Optional[Tuple[int, int]] = None,
  device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  """
  RoPE for video tokens with 3D structure.

  Args:
  embed_dim: (`int`):
      The embedding dimension size, corresponding to hidden_size_head.
  crops_coords (`Tuple[int]`):
      The top-left and bottom-right coordinates of the crop.
  grid_size (`Tuple[int]`):
      The grid size of the spatial positional embedding (height, width).
  temporal_size (`int`):
      The size of the temporal dimension.
  theta (`float`):
      Scaling factor for frequency computation.
  grid_type (`str`):
      Whether to use "linspace" or "slice" to compute grids.

  Returns:
      `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
  """
  if use_real is not True:
      raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

  if grid_type == "linspace":
      start, stop = crops_coords
      grid_size_h, grid_size_w = grid_size
      grid_h = torch.linspace(start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32)
      grid_w = torch.linspace(start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32)
      grid_t = torch.linspace(0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32)
  elif grid_type == "slice":
      max_h, max_w = max_size
      grid_size_h, grid_size_w = grid_size
      grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
      grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
      grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
  else:
      raise ValueError("Invalid value passed for `grid_type`.")

  # Compute dimensions for each axis
  dim_t = embed_dim // 4
  dim_h = embed_dim // 8 * 3
  dim_w = embed_dim // 8 * 3

  # Temporal frequencies
  freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
  # Spatial frequencies for height and width
  freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
  freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)

  # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
  def combine_time_height_width(freqs_t, freqs_h, freqs_w):
      freqs_t = freqs_t[:, None, None, :].expand(
          -1, grid_size_h, grid_size_w, -1
      )  # temporal_size, grid_size_h, grid_size_w, dim_t
      freqs_h = freqs_h[None, :, None, :].expand(
          temporal_size, -1, grid_size_w, -1
      )  # temporal_size, grid_size_h, grid_size_w, dim_h
      freqs_w = freqs_w[None, None, :, :].expand(
          temporal_size, grid_size_h, -1, -1
      )  # temporal_size, grid_size_h, grid_size_w, dim_w

      freqs = torch.cat(
          [freqs_t, freqs_h, freqs_w], dim=-1
      )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
      freqs = freqs.view(
          temporal_size * grid_size_h * grid_size_w, -1
      )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
      return freqs

  t_cos, t_sin = freqs_t  # t_cos and t_sin both have shape: temporal_size, dim_t
  h_cos, h_sin = freqs_h  # h_cos and h_sin both have shape: grid_size_h, dim_h
  w_cos, w_sin = freqs_w  # w_cos and w_sin both have shape: grid_size_w, dim_w

  if grid_type == "slice":
      t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
      h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
      w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

  cos = combine_time_height_width(t_cos, h_cos, w_cos)
  sin = combine_time_height_width(t_sin, h_sin, w_sin)
  return cos, sin

grid_crops_coords = get_resize_crop_region_for_grid(
  (grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos_np, freqs_sin_np = get_3d_rotary_pos_embed(
  embed_dim=attention_head_dim,
  crops_coords=grid_crops_coords,
  grid_size=(grid_height, grid_width),
  temporal_size=num_frames,
)

freqs_cos, freqs_sin = get_3d_rotary_pos_embed_torch(
  embed_dim=attention_head_dim,
  crops_coords=grid_crops_coords,
  grid_size=(grid_height, grid_width),
  temporal_size=num_frames,
)

torch.testing.assert_close(freqs_cos, freqs_cos_np)
torch.testing.assert_close(freqs_sin, freqs_sin_np)

base_num_frames = (num_frames + p_t - 1) // p_t
freqs_cos_np, freqs_sin_np = get_3d_rotary_pos_embed(
  embed_dim=attention_head_dim,
  crops_coords=None,
  grid_size=(grid_height, grid_width),
  temporal_size=base_num_frames,
  grid_type="slice",
  max_size=(base_size_height, base_size_width),
)

freqs_cos, freqs_sin = get_3d_rotary_pos_embed_torch(
  embed_dim=attention_head_dim,
  crops_coords=None,
  grid_size=(grid_height, grid_width),
  temporal_size=base_num_frames,
  grid_type="slice",
  max_size=(base_size_height, base_size_width),
)

torch.testing.assert_close(freqs_cos, freqs_cos_np)
torch.testing.assert_close(freqs_sin, freqs_sin_np)
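The expand-and-concatenate step inside combine_time_height_width can be sketched with small stand-in sizes (the T, H, W and per-axis dims below are hypothetical values chosen only for illustration):

```python
import torch

# Toy stand-ins for temporal_size, grid_size_h, grid_size_w and the
# per-axis frequency dims dim_t, dim_h, dim_w.
T, H, W, dt, dh, dw = 3, 4, 5, 2, 3, 3

# Each axis's frequencies are broadcast over the other two axes via expand,
# then concatenated along the last dim and flattened to (T*H*W, dt+dh+dw).
ft = torch.randn(T, dt)[:, None, None, :].expand(-1, H, W, -1)  # T, H, W, dt
fh = torch.randn(H, dh)[None, :, None, :].expand(T, -1, W, -1)  # T, H, W, dh
fw = torch.randn(W, dw)[None, None, :, :].expand(T, H, -1, -1)  # T, H, W, dw
freqs = torch.cat([ft, fh, fw], dim=-1).reshape(T * H * W, -1)
assert freqs.shape == (T * H * W, dt + dh + dw)  # (60, 8)
```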
Reproduction: get_3d_rotary_pos_embed_allegro
from diffusers.models.embeddings import get_1d_rotary_pos_embed, get_3d_rotary_pos_embed_allegro
import torch
from typing import Optional, Tuple, Union

sample_height = 160
sample_width = 90
sample_frames = 22
patch_size = 2

vae_scale_factor_spatial = 8
attention_head_dim = 96
num_frames = 22

height = sample_height * vae_scale_factor_spatial
width = sample_width * vae_scale_factor_spatial

grid_height = height // (vae_scale_factor_spatial * patch_size)
grid_width = width // (vae_scale_factor_spatial * patch_size)
start, stop = (0, 0), (grid_height, grid_width)

interpolation_scale_h = 2.0
interpolation_scale_w = 2.0
interpolation_scale_t = 2.2

freqs_t_np, freqs_h_np, freqs_w_np, grid_t_np, grid_h_np, grid_w_np = get_3d_rotary_pos_embed_allegro(
  embed_dim=attention_head_dim,
  crops_coords=(start, stop),
  grid_size=(grid_height, grid_width),
  temporal_size=num_frames,
  interpolation_scale=(
      interpolation_scale_t,
      interpolation_scale_h,
      interpolation_scale_w,
  ),
)


def get_3d_rotary_pos_embed_allegro_torch(
  embed_dim,
  crops_coords,
  grid_size,
  temporal_size,
  interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
  theta: int = 10000,
  device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  # TODO(aryan): docs
  start, stop = crops_coords
  grid_size_h, grid_size_w = grid_size
  interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
  grid_t = torch.linspace(0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32)
  grid_h = torch.linspace(start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32)
  grid_w = torch.linspace(start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32)

  # Compute dimensions for each axis
  dim_t = embed_dim // 3
  dim_h = embed_dim // 3
  dim_w = embed_dim // 3

  # Temporal frequencies
  freqs_t = get_1d_rotary_pos_embed(
      dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
  )
  # Spatial frequencies for height and width
  freqs_h = get_1d_rotary_pos_embed(
      dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
  )
  freqs_w = get_1d_rotary_pos_embed(
      dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
  )

  return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w

freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro_torch(
  embed_dim=attention_head_dim,
  crops_coords=(start, stop),
  grid_size=(grid_height, grid_width),
  temporal_size=num_frames,
  interpolation_scale=(
      interpolation_scale_t,
      interpolation_scale_h,
      interpolation_scale_w,
  ),
)

torch.testing.assert_close(freqs_t, freqs_t_np)
torch.testing.assert_close(freqs_h, freqs_h_np)
torch.testing.assert_close(freqs_w, freqs_w_np)
torch.testing.assert_close(grid_t, torch.from_numpy(grid_t_np))
torch.testing.assert_close(grid_h, torch.from_numpy(grid_h_np))
torch.testing.assert_close(grid_w, torch.from_numpy(grid_w_np))
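A small sanity check (mine, not from the PR) on the per-axis dimension splits the two functions use: the axis dims must sum back to embed_dim so the concatenated frequencies cover the full head dimension:

```python
# get_3d_rotary_pos_embed splits embed_dim as 1/4 temporal + 3/8 height
# + 3/8 width; get_3d_rotary_pos_embed_allegro splits it evenly in thirds.
embed_dim_cog = 64       # CogVideoX attention_head_dim used above
dim_t = embed_dim_cog // 4
dim_h = dim_w = embed_dim_cog // 8 * 3
assert dim_t + dim_h + dim_w == embed_dim_cog  # 16 + 24 + 24 == 64

embed_dim_allegro = 96   # Allegro attention_head_dim used above
assert 3 * (embed_dim_allegro // 3) == embed_dim_allegro  # 3 * 32 == 96
```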

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator) commented Dec 9, 2024

hi @hlky
The change looks good to me code-wise! I will ask @a-r-r-o-w to do a review.
Other than that, for this type of refactor, let's run all the slow tests for the affected models/pipelines. I will also run all the docstring examples for them and make sure the outputs are the same.

@a-r-r-o-w (Contributor) left a comment

The changes look good to me, but we probably need to be mindful of whether this will affect downstream repositories that rely on these functions. At least two come to mind immediately.

There are probably a lot more usages. I think it is a safe change, since the repositories that already rely on these methods do the device casting after the tensors are returned. Should be okay to merge if the slow tests pass, IMO!
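The casting pattern referred to here can be sketched as follows (a minimal illustration with a hypothetical stand-in function, not code from any particular repository): a `.to()` after the call is a no-op when the tensor is already on the target device, so the new device argument stays backwards-compatible:

```python
import torch

def fake_rope(device=None):
    # Stand-in for get_3d_rotary_pos_embed: now accepts an optional device.
    return torch.ones(4, 8, device=device), torch.zeros(4, 8, device=device)

# Existing downstream code casts after the call. That remains correct whether
# the function returns CPU tensors (old behavior) or tensors already on the
# target device (new behavior), since .to() is then a no-op.
target = torch.device("cpu")
cos, sin = fake_rope()
cos, sin = cos.to(target), sin.to(target)
assert cos.device.type == "cpu" and sin.device.type == "cpu"
```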

@hlky (Contributor, Author) commented Dec 10, 2024

Downstream usage of get_3d_rotary_pos_embed: it looks like these repositories handle device casting themselves, so this should be OK. Will update after the slow tests are run.

@hlky (Contributor, Author) commented Dec 10, 2024

CogVideoX slow tests pass, except for one OOM failure, which is unrelated (the same OOM occurs here).

@yiyixuxu yiyixuxu merged commit 4c4b323 into huggingface:main Dec 10, 2024
15 checks passed
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
Use torch in get_3d_rotary_pos_embed/_allegro