@@ -677,6 +677,21 @@ def fuse_projections(self, fuse=True):
677677                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
678678                self.to_kv.bias.copy_(concatenated_bias)
679679
680+        # handle added projections for SD3 and others.
681+        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
682+            concatenated_weights = torch.cat(
683+                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
684+            )
685+            in_features = concatenated_weights.shape[1]
686+            out_features = concatenated_weights.shape[0]
687+
688+            self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
689+            self.to_added_qkv.weight.copy_(concatenated_weights)
690+            concatenated_bias = torch.cat(
691+                [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
692+            )
693+            self.to_added_qkv.bias.copy_(concatenated_bias)
694+
680695        self.fused_projections = fuse
681696
682697
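The added block mirrors the existing `to_qkv` fusion for the context (encoder) projections: the three `add_*_proj` weight matrices and biases are stacked along the output dimension into a single `to_added_qkv` linear layer, so one matmul replaces three. Below is a minimal standalone sketch of that concatenate-then-split equivalence in plain PyTorch; the layer sizes are illustrative and this is not the diffusers `Attention` module itself.

```python
import torch
import torch.nn as nn

# Illustrative sizes; in diffusers these come from the Attention module's config.
dim, batch, seq = 64, 2, 8
add_q_proj = nn.Linear(dim, dim, bias=True)
add_k_proj = nn.Linear(dim, dim, bias=True)
add_v_proj = nn.Linear(dim, dim, bias=True)

# Fuse: stack the three weight matrices (and biases) along the output dimension.
with torch.no_grad():
    fused_weight = torch.cat([add_q_proj.weight, add_k_proj.weight, add_v_proj.weight])
    fused_bias = torch.cat([add_q_proj.bias, add_k_proj.bias, add_v_proj.bias])
    to_added_qkv = nn.Linear(fused_weight.shape[1], fused_weight.shape[0], bias=True)
    to_added_qkv.weight.copy_(fused_weight)
    to_added_qkv.bias.copy_(fused_bias)

# One matmul, then split back into q/k/v along the last dimension.
encoder_hidden_states = torch.randn(batch, seq, dim)
qkv = to_added_qkv(encoder_hidden_states)
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

assert torch.allclose(q, add_q_proj(encoder_hidden_states), atol=1e-6)
assert torch.allclose(k, add_k_proj(encoder_hidden_states), atol=1e-6)
assert torch.allclose(v, add_v_proj(encoder_hidden_states), atol=1e-6)
```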
@@ -1708,6 +1723,109 @@ def __call__(
17081723        return hidden_states
17091724
17101725
1726+class FusedHunyuanAttnProcessor2_0:
1727+    r"""
1728+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
1729+    projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on
1730+    the query and key vectors.
1731+    """
1732+
1733+    def __init__(self):
1734+        if not hasattr(F, "scaled_dot_product_attention"):
1735+            raise ImportError(
1736+                "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
1737+            )
1738+
1739+    def __call__(
1740+        self,
1741+        attn: Attention,
1742+        hidden_states: torch.Tensor,
1743+        encoder_hidden_states: Optional[torch.Tensor] = None,
1744+        attention_mask: Optional[torch.Tensor] = None,
1745+        temb: Optional[torch.Tensor] = None,
1746+        image_rotary_emb: Optional[torch.Tensor] = None,
1747+    ) -> torch.Tensor:
1748+        from .embeddings import apply_rotary_emb
1749+
1750+        residual = hidden_states
1751+        if attn.spatial_norm is not None:
1752+            hidden_states = attn.spatial_norm(hidden_states, temb)
1753+
1754+        input_ndim = hidden_states.ndim
1755+
1756+        if input_ndim == 4:
1757+            batch_size, channel, height, width = hidden_states.shape
1758+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1759+
1760+        batch_size, sequence_length, _ = (
1761+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1762+        )
1763+
1764+        if attention_mask is not None:
1765+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1766+            # scaled_dot_product_attention expects attention_mask shape to be
1767+            # (batch, heads, source_length, target_length)
1768+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1769+
1770+        if attn.group_norm is not None:
1771+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1772+
1773+        if encoder_hidden_states is None:
1774+            qkv = attn.to_qkv(hidden_states)
1775+            split_size = qkv.shape[-1] // 3
1776+            query, key, value = torch.split(qkv, split_size, dim=-1)
1777+        else:
1778+            if attn.norm_cross:
1779+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1780+            query = attn.to_q(hidden_states)
1781+
1782+            kv = attn.to_kv(encoder_hidden_states)
1783+            split_size = kv.shape[-1] // 2
1784+            key, value = torch.split(kv, split_size, dim=-1)
1785+
1786+        inner_dim = key.shape[-1]
1787+        head_dim = inner_dim // attn.heads
1788+
1789+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1790+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1791+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1792+
1793+        if attn.norm_q is not None:
1794+            query = attn.norm_q(query)
1795+        if attn.norm_k is not None:
1796+            key = attn.norm_k(key)
1797+
1798+        # Apply RoPE if needed
1799+        if image_rotary_emb is not None:
1800+            query = apply_rotary_emb(query, image_rotary_emb)
1801+            if not attn.is_cross_attention:
1802+                key = apply_rotary_emb(key, image_rotary_emb)
1803+
1804+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
1805+        # TODO: add support for attn.scale when we move to Torch 2.1
1806+        hidden_states = F.scaled_dot_product_attention(
1807+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1808+        )
1809+
1810+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1811+        hidden_states = hidden_states.to(query.dtype)
1812+
1813+        # linear proj
1814+        hidden_states = attn.to_out[0](hidden_states)
1815+        # dropout
1816+        hidden_states = attn.to_out[1](hidden_states)
1817+
1818+        if input_ndim == 4:
1819+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1820+
1821+        if attn.residual_connection:
1822+            hidden_states = hidden_states + residual
1823+
1824+        hidden_states = hidden_states / attn.rescale_output_factor
1825+
1826+        return hidden_states
1827+
1828+
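For context, the core of the new processor's self-attention branch is the standard fused-QKV pattern: a single `to_qkv` projection whose output is split into query, key and value, reshaped per head, and passed to `scaled_dot_product_attention`. The following is a minimal, self-contained sketch of that pattern with assumed sizes; it deliberately leaves out the QK norms, rotary embeddings, masking, and the cross-attention `to_kv` branch handled in the processor above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed, illustrative sizes.
batch, seq, inner_dim, heads = 2, 16, 64, 4
head_dim = inner_dim // heads

to_qkv = nn.Linear(inner_dim, 3 * inner_dim, bias=True)  # fused q/k/v projection
to_out = nn.Linear(inner_dim, inner_dim, bias=True)       # output projection

hidden_states = torch.randn(batch, seq, inner_dim)

# One matmul produces q, k and v; split along the channel dimension.
qkv = to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)

# (batch, seq, inner_dim) -> (batch, heads, seq, head_dim) for SDPA.
query = query.view(batch, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch, -1, heads, head_dim).transpose(1, 2)

attn_out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# Merge heads back and apply the output projection.
attn_out = attn_out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
out = to_out(attn_out)
print(out.shape)  # torch.Size([2, 16, 64])
```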
17111829 class LuminaAttnProcessor2_0:
17121830     r"""
17131831     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is