
Commit c291617

Authored by yiyixuxu, sayakpaul, and jsmidt
Flux followup (huggingface#9074)
* refactor rotary embeds
* adding jsmidt as co-author of this PR for huggingface#9133

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Joseph Smidt <[email protected]>
1 parent 9003d75 commit c291617
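
The substance of the refactor: the Flux-specific `apply_rope` helper is removed from `attention_processor.py`, and both Flux attention paths now go through the shared `apply_rotary_emb` helper in `diffusers.models.embeddings`, with rotary frequencies computed in float64 (float32 on MPS) and returned as real-valued cos/sin tables. The snippet below is a minimal, self-contained sketch of that real-valued rotation for readers who want the math outside the diff; `rotate_half_interleaved` and `apply_rotary_emb_sketch` are hypothetical names for illustration, not diffusers API.

```python
# Minimal sketch (not library code) of the real-valued RoPE rotation the shared
# helper applies, using the interleaved cos/sin layout
# (`repeat_interleave_real=True`) that the Flux path in this commit relies on.
import torch


def rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor:
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)


def apply_rotary_emb_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq, head_dim]; cos/sin: [seq, head_dim], repeat-interleaved per pair.
    cos = cos[None, None].to(x.dtype)
    sin = sin[None, None].to(x.dtype)
    return x * cos + rotate_half_interleaved(x) * sin


# Toy frequencies for 8 positions and a 16-dim head, mirroring the float64
# computation that get_1d_rotary_pos_embed now performs before casting to float32.
pos = torch.arange(8, dtype=torch.float64)
freqs = 1.0 / (10000 ** (torch.arange(0, 16, 2, dtype=torch.float64) / 16))
angles = torch.outer(pos, freqs)                        # [8, 8]
cos = angles.cos().repeat_interleave(2, dim=1).float()  # [8, 16]
sin = angles.sin().repeat_interleave(2, dim=1).float()  # [8, 16]

query = torch.randn(1, 2, 8, 16)
print(apply_rotary_emb_sketch(query, cos, sin).shape)   # torch.Size([1, 2, 8, 16])
```

The intent of the refactor (per the PR title and the hunks below) is that this formulation produces the same rotation as the removed `apply_rope`, which expressed it through stacked 2x2 rotation entries in `freqs_cis`; only the frequency representation and dtype handling change.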

File tree

8 files changed (+159, -202 lines changed)


src/diffusers/models/attention_processor.py

Lines changed: 53 additions & 127 deletions
@@ -1695,81 +1695,6 @@ def __call__(
         return hidden_states


-# YiYi to-do: refactor rope related functions/classes
-def apply_rope(xq, xk, freqs_cis):
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
-
-class FluxSingleAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        query = attn.to_q(hidden_states)
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        return hidden_states
-
-
 class FluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""

@@ -1785,16 +1710,7 @@ def __call__(
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
         query = attn.to_q(hidden_states)
@@ -1813,59 +1729,58 @@ def __call__(
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

-        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-
-        if attn.norm_added_q is not None:
-            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-        if attn.norm_added_k is not None:
-            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)

-        # attention
-        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

         if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)

         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

-        encoder_hidden_states, hidden_states = (
-            hidden_states[:, : encoder_hidden_states.shape[1]],
-            hidden_states[:, encoder_hidden_states.shape[1] :],
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )

-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

-        return hidden_states, encoder_hidden_states
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states


 class XFormersAttnAddedKVProcessor:
@@ -4105,6 +4020,17 @@ def __init__(self):
         pass


+class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
+        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
+        super().__init__()
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
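
Two practical takeaways from this file: `FluxSingleAttnProcessor2_0` survives only as a deprecated alias of `FluxAttnProcessor2_0` (scheduled for removal in 0.32.0), and the unified processor now changes its return type depending on whether `encoder_hidden_states` is supplied. The toy function below mirrors just that calling contract; `run_flux_attention` is a hypothetical stand-in for `attn.processor(attn, ...)`, not the real implementation.

```python
# Sketch of the return-type contract of the unified Flux attention processor.
from typing import Optional, Tuple, Union

import torch


def run_flux_attention(
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Stand-in for the QKV projections + scaled_dot_product_attention.
    if encoder_hidden_states is None:
        # Single-stream (FluxSingleTransformerBlock) path: one tensor in, one tensor out.
        return hidden_states
    # Joint (double-stream) path: context tokens are prepended for attention,
    # then split back out, and both streams are returned as a tuple.
    joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    ctx_len = encoder_hidden_states.shape[1]
    return joint[:, ctx_len:], joint[:, :ctx_len]


img = torch.randn(1, 16, 64)   # image-stream tokens
txt = torch.randn(1, 4, 64)    # text-stream tokens
single_out = run_flux_attention(img)                 # torch.Tensor
joint_img, joint_txt = run_flux_attention(img, txt)  # (hidden_states, encoder_hidden_states)
```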

src/diffusers/models/controlnet_flux.py

Lines changed: 17 additions & 5 deletions
@@ -24,9 +24,9 @@
 from ..models.modeling_utils import ModelMixin
 from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from .controlnet import BaseOutput, zero_module
-from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from .modeling_outputs import Transformer2DModelOutput
-from .transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
+from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -59,7 +59,7 @@ def __init__(
         self.out_channels = in_channels
         self.inner_dim = num_attention_heads * attention_head_dim

-        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
@@ -272,8 +272,20 @@ def forward(
            )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

-        txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            img_ids = img_ids[0]
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)

         block_samples = ()
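
For callers of `FluxControlNetModel.forward` (and the matching Flux transformer), `txt_ids` and `img_ids` are now expected as 2d tensors without a batch dimension; 3d inputs still work but trigger a deprecation warning and have their batch dimension dropped. A short sketch of the new convention, with sequence lengths and the id layout chosen purely for illustration:

```python
import torch

text_seq_len, image_seq_len = 512, 4096  # illustrative token counts

# Old (deprecated): 3d ids carrying an explicit batch dimension, e.g.
#   txt_ids = torch.zeros(batch_size, text_seq_len, 3)

# New: 2d ids, shared across the batch.
txt_ids = torch.zeros(text_seq_len, 3)   # text tokens use all-zero positions in the Flux pipelines
img_ids = torch.zeros(image_seq_len, 3)  # image ids typically carry (0, row, col) per latent patch

ids = torch.cat((txt_ids, img_ids), dim=0)  # [text_seq_len + image_seq_len, 3]
# image_rotary_emb = pos_embed(ids)         # FluxPosEmbed returns a (cos, sin) pair
```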

src/diffusers/models/embeddings.py

Lines changed: 35 additions & 7 deletions
@@ -446,6 +446,7 @@ def get_1d_rotary_pos_embed(
     linear_factor=1.0,
     ntk_factor=1.0,
     repeat_interleave_real=True,
+    freqs_dtype=torch.float32,  # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -468,6 +469,8 @@
         repeat_interleave_real (`bool`, *optional*, defaults to `True`):
             If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
             Otherwise, they are concateanted with themselves.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            the dtype of the frequency tensor.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
@@ -476,19 +479,19 @@
     if isinstance(pos, int):
         pos = np.arange(pos)
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
     t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
-    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    freqs = torch.outer(t, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
     else:
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float()  # complex64  # [S, D/2]
         return freqs_cis


@@ -540,6 +543,31 @@ def apply_rotary_emb(
     return x_out.type_as(x)


+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.squeeze().float().cpu().numpy()
+        is_mps = ids.device.type == "mps"
+        freqs_dtype = torch.float32 if is_mps else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
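
A quick usage sketch for the new `FluxPosEmbed` module added above, assuming a diffusers build that includes this commit; the `axes_dim` values mirror the (16, 56, 56) rope axes used by the Flux transformer config and are illustrative here:

```python
import torch

from diffusers.models.embeddings import FluxPosEmbed  # exported from embeddings.py after this change

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# 2d ids: one (id, row, col) triple per token, no batch dimension.
ids = torch.zeros(512 + 4096, 3)
freqs_cos, freqs_sin = pos_embed(ids)
print(freqs_cos.shape, freqs_sin.shape)  # both [4608, 128], i.e. [seq_len, sum(axes_dim)]
```

Internally the module computes one 1d rotary table per axis via `get_1d_rotary_pos_embed(..., use_real=True, repeat_interleave_real=True, freqs_dtype=torch.float64)` (float32 on MPS) and concatenates the per-axis cos/sin tables along the last dimension.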
