@@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed(
391391 The size of the temporal dimension.
392392 theta (`float`):
393393 Scaling factor for frequency computation.
394- use_real (`bool`):
395- If True, return real part and imaginary part separately. Otherwise, return complex numbers.
396394
397395 Returns:
398396 `Tuple[torch.Tensor, torch.Tensor]`: cosine and sine positional embeddings, each with shape `(temporal_size * grid_size[0] * grid_size[1], dim_t + dim_h + dim_w)`.
399397 """
398+ if use_real is not True :
399+ raise ValueError (" `use_real = False` is not currently supported for get_3d_rotary_pos_embed" )
400400 start , stop = crops_coords
401- grid_h = np .linspace (start [0 ], stop [0 ], grid_size [0 ], endpoint = False , dtype = np .float32 )
402- grid_w = np .linspace (start [1 ], stop [1 ], grid_size [1 ], endpoint = False , dtype = np .float32 )
401+ grid_size_h , grid_size_w = grid_size
402+ grid_h = np .linspace (start [0 ], stop [0 ], grid_size_h , endpoint = False , dtype = np .float32 )
403+ grid_w = np .linspace (start [1 ], stop [1 ], grid_size_w , endpoint = False , dtype = np .float32 )
403404 grid_t = np .linspace (0 , temporal_size , temporal_size , endpoint = False , dtype = np .float32 )
404405
405406 # Compute dimensions for each axis
@@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed(
408409 dim_w = embed_dim // 8 * 3
409410
410411 # Temporal frequencies
411- freqs_t = 1.0 / (theta ** (torch .arange (0 , dim_t , 2 ).float () / dim_t ))
412- grid_t = torch .from_numpy (grid_t ).float ()
413- freqs_t = torch .einsum ("n , f -> n f" , grid_t , freqs_t )
414- freqs_t = freqs_t .repeat_interleave (2 , dim = - 1 )
415-
412+ freqs_t = get_1d_rotary_pos_embed (dim_t , grid_t , use_real = True )
416413 # Spatial frequencies for height and width
417- freqs_h = 1.0 / (theta ** (torch .arange (0 , dim_h , 2 ).float () / dim_h ))
418- freqs_w = 1.0 / (theta ** (torch .arange (0 , dim_w , 2 ).float () / dim_w ))
419- grid_h = torch .from_numpy (grid_h ).float ()
420- grid_w = torch .from_numpy (grid_w ).float ()
421- freqs_h = torch .einsum ("n , f -> n f" , grid_h , freqs_h )
422- freqs_w = torch .einsum ("n , f -> n f" , grid_w , freqs_w )
423- freqs_h = freqs_h .repeat_interleave (2 , dim = - 1 )
424- freqs_w = freqs_w .repeat_interleave (2 , dim = - 1 )
425-
426- # Broadcast and concatenate tensors along specified dimension
427- def broadcast (tensors , dim = - 1 ):
428- num_tensors = len (tensors )
429- shape_lens = {len (t .shape ) for t in tensors }
430- assert len (shape_lens ) == 1 , "tensors must all have the same number of dimensions"
431- shape_len = list (shape_lens )[0 ]
432- dim = (dim + shape_len ) if dim < 0 else dim
433- dims = list (zip (* (list (t .shape ) for t in tensors )))
434- expandable_dims = [(i , val ) for i , val in enumerate (dims ) if i != dim ]
435- assert all (
436- [* (len (set (t [1 ])) <= 2 for t in expandable_dims )]
437- ), "invalid dimensions for broadcastable concatenation"
438- max_dims = [(t [0 ], max (t [1 ])) for t in expandable_dims ]
439- expanded_dims = [(t [0 ], (t [1 ],) * num_tensors ) for t in max_dims ]
440- expanded_dims .insert (dim , (dim , dims [dim ]))
441- expandable_shapes = list (zip (* (t [1 ] for t in expanded_dims )))
442- tensors = [t [0 ].expand (* t [1 ]) for t in zip (tensors , expandable_shapes )]
443- return torch .cat (tensors , dim = dim )
444-
445- freqs = broadcast ((freqs_t [:, None , None , :], freqs_h [None , :, None , :], freqs_w [None , None , :, :]), dim = - 1 )
446-
447- t , h , w , d = freqs .shape
448- freqs = freqs .view (t * h * w , d )
449-
450- # Generate sine and cosine components
451- sin = freqs .sin ()
452- cos = freqs .cos ()
453-
454- if use_real :
455- return cos , sin
456- else :
457- freqs_cis = torch .polar (torch .ones_like (freqs ), freqs )
458- return freqs_cis
414+ freqs_h = get_1d_rotary_pos_embed (dim_h , grid_h , use_real = True )
415+ freqs_w = get_1d_rotary_pos_embed (dim_w , grid_w , use_real = True )
416+
417+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
418+ def combine_time_height_width (freqs_t , freqs_h , freqs_w ):
419+ freqs_t = freqs_t [:, None , None , :].expand (
420+ - 1 , grid_size_h , grid_size_w , - 1
421+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
422+ freqs_h = freqs_h [None , :, None , :].expand (
423+ temporal_size , - 1 , grid_size_w , - 1
424+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
425+ freqs_w = freqs_w [None , None , :, :].expand (
426+ temporal_size , grid_size_h , - 1 , - 1
427+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
428+
429+ freqs = torch .cat (
430+ [freqs_t , freqs_h , freqs_w ], dim = - 1
431+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
432+ freqs = freqs .view (
433+ temporal_size * grid_size_h * grid_size_w , - 1
434+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
435+ return freqs
436+
437+ t_cos , t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
438+ h_cos , h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
439+ w_cos , w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
440+ cos = combine_time_height_width (t_cos , h_cos , w_cos )
441+ sin = combine_time_height_width (t_sin , h_sin , w_sin )
442+ return cos , sin
459443
460444
461445def get_2d_rotary_pos_embed (embed_dim , crops_coords , grid_size , use_real = True ):
0 commit comments