@@ -1695,81 +1695,6 @@ def __call__(
         return hidden_states


-# YiYi to-do: refactor rope related functions/classes
-def apply_rope(xq, xk, freqs_cis):
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
-
-class FluxSingleAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        query = attn.to_q(hidden_states)
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        return hidden_states
-
-
 class FluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""

@@ -1785,16 +1710,7 @@ def __call__(
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
         query = attn.to_q(hidden_states)
@@ -1813,59 +1729,58 @@ def __call__(
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

-        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-
-        if attn.norm_added_q is not None:
-            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-        if attn.norm_added_k is not None:
-            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)

-        # attention
-        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

         if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)

         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

-        encoder_hidden_states, hidden_states = (
-            hidden_states[:, : encoder_hidden_states.shape[1]],
-            hidden_states[:, encoder_hidden_states.shape[1] :],
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )

-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

-        return hidden_states, encoder_hidden_states
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states


 class XFormersAttnAddedKVProcessor:
@@ -4105,6 +4020,17 @@ def __init__(self):
         pass


+class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
+        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
+        super().__init__()
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
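For context, the sketch below (not part of the diff) shows how the unified processor behaves after this change: with `encoder_hidden_states` it returns a `(hidden_states, encoder_hidden_states)` tuple for the dual-stream block, and without it a single tensor, which is what `FluxSingleTransformerBlock` now relies on instead of the removed `FluxSingleAttnProcessor2_0`. The `Attention` constructor arguments and tensor sizes are illustrative assumptions about the diffusers API of this era, not taken from the commit.

# Usage sketch (assumptions noted above; verify the Attention kwargs against your diffusers version).
import torch

from diffusers.models.attention_processor import Attention, FluxAttnProcessor2_0

dim, heads, head_dim = 128, 4, 32

# Dual-stream (joint) attention: context projections (add_q/k/v_proj, to_add_out) are used,
# and the processor returns both streams.
joint_attn = Attention(
    query_dim=dim,
    added_kv_proj_dim=dim,   # creates the add_q/k/v_proj context projections
    heads=heads,
    dim_head=head_dim,
    out_dim=dim,
    context_pre_only=False,  # keeps to_add_out so the context stream is projected back out
    bias=True,
    processor=FluxAttnProcessor2_0(),
)
image_tokens = torch.randn(1, 16, dim)
text_tokens = torch.randn(1, 8, dim)
image_out, text_out = joint_attn(hidden_states=image_tokens, encoder_hidden_states=text_tokens)

# Single-stream attention: no encoder_hidden_states, so the processor skips the context branch
# and the output projections and returns a single tensor, matching the deprecated
# FluxSingleAttnProcessor2_0 behaviour.
single_attn = Attention(
    query_dim=dim,
    heads=heads,
    dim_head=head_dim,
    out_dim=dim,
    bias=True,
    processor=FluxAttnProcessor2_0(),
)
single_out = single_attn(hidden_states=image_tokens)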