
Commit 429e544

Add attention mask to uclip (huggingface#1756)
* Remove bogus file
* [Unclip] Add efficient attention
* [Unclip] Add efficient attention
1 parent dc7cd89 commit 429e544

File tree: 3 files changed, +16 −42 lines

src/diffusers/models/attention.py
src/diffusers/models/unet_2d_blocks.py
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

src/diffusers/models/attention.py

Lines changed: 16 additions & 8 deletions
@@ -621,6 +621,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
             key = self.reshape_heads_to_batch_dim(key)
             value = self.reshape_heads_to_batch_dim(value)
 
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != query.shape[1]:
+                target_length = query.shape[1]
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
@@ -630,7 +636,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value, attention_mask)
             else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
 
         # linear proj
         hidden_states = self.to_out[0](hidden_states)
@@ -653,11 +659,6 @@ def _attention(self, query, key, value, attention_mask=None):
         )
 
         if attention_mask is not None:
-            if attention_mask.shape != attention_scores.shape:
-                target_length = query.shape[1]
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
-
             attention_scores = attention_scores + attention_mask
 
         if self.upcast_softmax:
@@ -675,7 +676,7 @@ def _attention(self, query, key, value, attention_mask=None):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
-    def _sliced_attention(self, query, key, value, sequence_length, dim):
+    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
@@ -699,6 +700,13 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
                 beta=0,
                 alpha=self.scale,
             )
+
+            if attention_mask is not None:
+                attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+            if self.upcast_softmax:
+                attn_slice = attn_slice.float()
+
             attn_slice = attn_slice.softmax(dim=-1)
 
             # cast back to the original dtype
@@ -716,7 +724,7 @@ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
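
Note on what moved where: the mask padding and head broadcasting now happen once in forward, before the dispatch, so the plain, sliced, and xformers paths all see the same pre-processed attention_mask (the sliced path indexes it with attention_mask[start_idx:end_idx], the xformers path passes it as attn_bias). A minimal standalone sketch of that shape logic, using made-up sizes (batch=2, heads=8, 77 text tokens, 256 query positions) that are not taken from the commit:

import torch
import torch.nn.functional as F

batch, heads, text_len, query_len = 2, 8, 77, 256   # illustrative sizes only

# additive mask: 0.0 where attention is allowed, a large negative value where it is not
attention_mask = torch.zeros(batch, 1, text_len)
attention_mask[..., 64:] = -10000.0                  # pretend the last 13 text tokens are padding

# query as it looks after reshape_heads_to_batch_dim: (batch * heads, seq_len, head_dim)
query = torch.randn(batch * heads, query_len, 64)

# the block this commit adds to CrossAttention.forward (self.heads replaced by the local heads)
if attention_mask is not None:
    if attention_mask.shape[-1] != query.shape[1]:
        target_length = query.shape[1]
        attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)  # extend to key length
        attention_mask = attention_mask.repeat_interleave(heads, dim=0)        # one copy per head

print(attention_mask.shape)   # torch.Size([16, 1, 333]) -> (batch * heads, 1, text_len + query_len)

The padded positions receive 0.0, i.e. they stay attendable; only the original text-padding entries carry the large negative bias that _attention later adds to the attention scores.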

src/diffusers/models/unet_2d_blocks.py

Lines changed: 0 additions & 17 deletions
@@ -564,23 +564,6 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
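
For reference, the block-level override deleted here only validated the requested slice size against the head counts and then forwarded it to each attention layer; only forward remains on the block. A self-contained sketch of that deleted check, with a hypothetical head_dims value, shows its two failure modes:

def validate_slice_size(slice_size, head_dims):
    # head_dims: number of attention heads per cross-attention layer, e.g. [8, 8, 16]
    head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
    if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
        raise ValueError(
            f"Make sure slice_size {slice_size} is a common divisor of "
            f"the number of heads used in cross_attention: {head_dims}"
        )
    if slice_size is not None and slice_size > min(head_dims):
        raise ValueError(
            f"slice_size {slice_size} has to be smaller or equal to "
            f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
        )

validate_slice_size(4, [8, 8, 16])       # ok: 4 divides every head count and 4 <= 8
try:
    validate_slice_size(3, [8, 8, 16])   # 3 is not a common divisor of [8, 8, 16]
except ValueError as err:
    print(err)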

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 0 additions & 17 deletions
@@ -1250,23 +1250,6 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
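
The same override is removed from the flat-UNet copy used by Versatile Diffusion, so both variants now simply thread attention_mask through forward into the attention layers, where it is added to the raw scores before the softmax (attention_scores + attention_mask in _attention above). A hedged sketch of the usual way such an additive bias is built from a 0/1 tokenizer padding mask; the exact conversion the UnCLIP pipeline performs is not part of this diff:

import torch

# 0/1 mask from a tokenizer: 1 = real token, 0 = padding (toy values)
token_mask = torch.tensor([[1, 1, 1, 0, 0]])

# additive form consumed by the attention code above:
# zeros for the three real tokens, -10000 for the two padding slots
attention_mask = (1.0 - token_mask.float()) * -10000.0
attention_mask = attention_mask.unsqueeze(1)   # shape (batch, 1, key_len), broadcast over queries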
