Skip to content

Commit 1ca0a75

Browse files
authored
refactor 3d rope for cogvideox (huggingface#9269)
* refactor 3d rope * repeat -> expand
1 parent c1e6a32 commit 1ca0a75

File tree

2 files changed

+35
-52
lines changed

2 files changed

+35
-52
lines changed

src/diffusers/models/embeddings.py

Lines changed: 35 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed(
391391
The size of the temporal dimension.
392392
theta (`float`):
393393
Scaling factor for frequency computation.
394-
use_real (`bool`):
395-
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
396394
397395
Returns:
398396
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
399397
"""
398+
if use_real is not True:
399+
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
400400
start, stop = crops_coords
401-
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
402-
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
401+
grid_size_h, grid_size_w = grid_size
402+
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
403+
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
403404
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
404405

405406
# Compute dimensions for each axis
@@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed(
408409
dim_w = embed_dim // 8 * 3
409410

410411
# Temporal frequencies
411-
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
412-
grid_t = torch.from_numpy(grid_t).float()
413-
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
414-
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
415-
412+
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
416413
# Spatial frequencies for height and width
417-
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
418-
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
419-
grid_h = torch.from_numpy(grid_h).float()
420-
grid_w = torch.from_numpy(grid_w).float()
421-
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
422-
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
423-
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
424-
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
425-
426-
# Broadcast and concatenate tensors along specified dimension
427-
def broadcast(tensors, dim=-1):
428-
num_tensors = len(tensors)
429-
shape_lens = {len(t.shape) for t in tensors}
430-
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
431-
shape_len = list(shape_lens)[0]
432-
dim = (dim + shape_len) if dim < 0 else dim
433-
dims = list(zip(*(list(t.shape) for t in tensors)))
434-
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
435-
assert all(
436-
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
437-
), "invalid dimensions for broadcastable concatenation"
438-
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
439-
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
440-
expanded_dims.insert(dim, (dim, dims[dim]))
441-
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
442-
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
443-
return torch.cat(tensors, dim=dim)
444-
445-
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
446-
447-
t, h, w, d = freqs.shape
448-
freqs = freqs.view(t * h * w, d)
449-
450-
# Generate sine and cosine components
451-
sin = freqs.sin()
452-
cos = freqs.cos()
453-
454-
if use_real:
455-
return cos, sin
456-
else:
457-
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
458-
return freqs_cis
414+
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
415+
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
416+
417+
# BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
418+
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
419+
freqs_t = freqs_t[:, None, None, :].expand(
420+
-1, grid_size_h, grid_size_w, -1
421+
) # temporal_size, grid_size_h, grid_size_w, dim_t
422+
freqs_h = freqs_h[None, :, None, :].expand(
423+
temporal_size, -1, grid_size_w, -1
424+
) # temporal_size, grid_size_h, grid_size_2, dim_h
425+
freqs_w = freqs_w[None, None, :, :].expand(
426+
temporal_size, grid_size_h, -1, -1
427+
) # temporal_size, grid_size_h, grid_size_2, dim_w
428+
429+
freqs = torch.cat(
430+
[freqs_t, freqs_h, freqs_w], dim=-1
431+
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
432+
freqs = freqs.view(
433+
temporal_size * grid_size_h * grid_size_w, -1
434+
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
435+
return freqs
436+
437+
t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
438+
h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
439+
w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
440+
cos = combine_time_height_width(t_cos, h_cos, w_cos)
441+
sin = combine_time_height_width(t_sin, h_sin, w_sin)
442+
return cos, sin
459443

460444

461445
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,6 @@ def _prepare_rotary_positional_embeddings(
463463
crops_coords=grid_crops_coords,
464464
grid_size=(grid_height, grid_width),
465465
temporal_size=num_frames,
466-
use_real=True,
467466
)
468467

469468
freqs_cos = freqs_cos.to(device=device)

0 commit comments

Comments
 (0)