Correct bad attn naming #3797
Conversation
The documentation is not available anymore as the PR was closed or merged.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000 configurations is too backwards breaking
# which is why we correct for the naming here.
admitting naming mess-up 😅
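To make the intent of the quoted comment concrete, here is a minimal, self-contained sketch of that kind of backwards-compatible fallback. It is illustrative only; the function name and signature are hypothetical and do not come from the diffusers codebase.

```python
# Hypothetical sketch, not the actual diffusers implementation.
def resolve_num_attention_heads(num_attention_heads=None, attention_head_dim=None):
    # Old configs stored the number of heads under the misnamed `attention_head_dim`
    # key, so fall back to it when the new argument is not provided. This keeps the
    # ~40,000 existing configurations loadable without a breaking rename.
    if num_attention_heads is None:
        num_attention_heads = attention_head_dim
    return num_attention_heads

print(resolve_num_attention_heads(attention_head_dim=8))                          # -> 8 (legacy config)
print(resolve_num_attention_heads(num_attention_heads=16, attention_head_dim=8))  # -> 16 (new argument wins)
```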
@@ -35,7 +35,7 @@ def get_down_block(
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
Renaming is fine here as none of the classes are public classes
Good that you identified this still relatively early.
If you can, please share tips for introducing these kinds of changes efficiently.
heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels,
heads=in_channels // attn_num_heads if attn_num_heads is not None else 1,
dim_head=attn_num_heads if attn_num_heads is not None else in_channels,
My head is 🤯 a little bit now reading this PR, so I could be wrong, but from what I understand `attn_num_head_channels` has been passed correctly as `dim_head` for all attention classes except for `Transformer2D`, and this change will cause unexpected behavior when we use the new `num_attention_heads` argument from `UNet2DConditionModel` for the other attention classes.
Taking `AttnDownBlock2D` as an example here, using the `attention_head_dim` argument returns the expected result:
from diffusers import UNet2DConditionModel

down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    attention_head_dim=16,
    block_out_channels=(320,),
    down_block_types=down_block_types,
    up_block_types=up_block_types)
# this prints 20
unet.down_blocks[0].attentions[0].heads
Using the `num_attention_heads` argument returns the wrong result:
down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    num_attention_heads=16,
    block_out_channels=(320,),
    down_block_types=down_block_types,
    up_block_types=up_block_types)
# this prints 20, which is wrong? should be 16
unet.down_blocks[0].attentions[0].heads
Maybe we need to pass both `num_attention_heads` and `attention_head_dim` to the blocks?
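As a rough illustration of what "pass both" could buy us, here is a small, self-contained toy (hypothetical helper name, not the diffusers API) that picks `heads`/`dim_head` depending on which argument is provided:

```python
# Toy sketch with hypothetical names; mirrors the heads/dim_head logic discussed above.
def build_attention_kwargs(in_channels, num_attention_heads=None, attention_head_dim=None):
    if attention_head_dim is not None:
        # Interpret the value as the per-head dimension (AttnDownBlock2D-style).
        heads, dim_head = in_channels // attention_head_dim, attention_head_dim
    elif num_attention_heads is not None:
        # Interpret the value as the number of heads (Transformer2D-style).
        heads, dim_head = num_attention_heads, in_channels // num_attention_heads
    else:
        heads, dim_head = 1, in_channels
    return {"heads": heads, "dim_head": dim_head}

print(build_attention_kwargs(320, attention_head_dim=16))   # {'heads': 20, 'dim_head': 16}
print(build_attention_kwargs(320, num_attention_heads=16))  # {'heads': 16, 'dim_head': 20}
```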
You're 100% right, great catch! I think there was a double incorrect naming correcting itself here haha
Shouldn't I expect some mapping between what the user specifies as `num_attention_heads` and `attention_head_dim` and how it's handled internally?
To put this in perspective:
from diffusers import UNet2DConditionModel

down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    num_attention_heads=16,
    block_out_channels=(320,),
    down_block_types=down_block_types,
    up_block_types=up_block_types,
)
# this prints 40
unet.down_blocks[0].attentions[0].heads
and
down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    attention_head_dim=16,
    block_out_channels=(320,),
    down_block_types=down_block_types,
    up_block_types=up_block_types)
# this prints 20
unet.down_blocks[0].attentions[0].heads
Are all these expected? Shouldn't `unet.down_blocks[0].attentions[0].heads` print 16, unless I am missing out on something?
Just so you don't miss it, pinging @patrickvonplaten (apologies in advance).
@sayakpaul `num_attention_heads` is not used for `AttnDownBlock2D`, only for `CrossAttnDownBlock2D`. If you run your first test with that type of block, `num_attention_heads` would be 16 as expected. But yes, I understand it can be confusing. Not sure how we can deal with it, perhaps we can follow up in a new PR?
Hmm, I don't immediately have any idea how to mitigate that faithfully, though.
@pcuenca I thought both `AttnDownBlock2D` and `CrossAttnDownBlock2D` need this argument, no? Anything with the `Attention` class does.
Did another pass and left some comments.
How can we best test the changes to ensure robustness here?
Interesting PR and issue! Asked some clarifications on a couple of details :)
@@ -219,7 +228,7 @@ def __init__(
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
Shouldn't we pass `attn_head_dim=attention_head_dim` here too? We are ignoring it (replaced with `num_attention_heads`), but then `get_down_block` will complain that it's recommended to pass `attn_head_dim` and default to copying it from `num_attention_heads`. We have all the information at this point in the caller.
Also, let's make the naming consistent when we can (`attn_head_dim` vs `attention_head_dim`).
Yes, good point!
):
# If attn head dim is not defined, we default it to the number of heads
if attn_head_dim is None:
if attn_head_dim is None:
if attn_head_dim is None and num_attention_heads is not None:
When we call with `None` (i.e., in the VAE), there's no point in showing the warning IMO.
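A tiny sketch of the gate being suggested, with hypothetical names (not the actual `get_down_block` code): the warning and the fallback only fire when there is a `num_attention_heads` value to copy from.

```python
import warnings

# Hypothetical sketch of the suggested warning gate; names are illustrative only.
def check_attn_head_dim(attn_head_dim, num_attention_heads):
    if attn_head_dim is None and num_attention_heads is not None:
        warnings.warn(
            "It is recommended to provide `attn_head_dim`; defaulting it to `num_attention_heads`."
        )
        attn_head_dim = num_attention_heads
    return attn_head_dim

check_attn_head_dim(None, 8)     # warns, returns 8
check_attn_head_dim(None, None)  # no warning (e.g. the VAE case), returns None
```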
We should not pass `None` anymore to `get_up_blocks` and `get_down_blocks` IMO. I'm correcting the VAE here.
@@ -221,7 +233,8 @@ def __init__(
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
Let's make sure we never pass `attention_head_dim[i] = None` to the `get_up_block` / `get_down_block` function. This reduces the black magic in the block code and makes it easier for the reader to understand how things are defined for SD.
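A minimal sketch of this idea, with hypothetical names (not the actual diffusers code): resolve the per-block values once in the caller, so the block builders never see `None`.

```python
# Hypothetical helper: broadcast and resolve attention_head_dim per block up front.
def resolve_attention_head_dim(attention_head_dim, block_out_channels):
    if attention_head_dim is None or isinstance(attention_head_dim, int):
        # Broadcast a single value (or None) to one entry per block.
        attention_head_dim = (attention_head_dim,) * len(block_out_channels)
    # Fall back to the block's output channels, mirroring the
    # `... if attention_head_dim[i] is not None else output_channel` expression above.
    return tuple(
        dim if dim is not None else out_ch
        for dim, out_ch in zip(attention_head_dim, block_out_channels)
    )

print(resolve_attention_head_dim(None, (320, 640)))  # (320, 640)
print(resolve_attention_head_dim(8, (320, 640)))     # (8, 8)
```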
Awesome! I think the latest changes are easier and require fewer warnings.
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -398,6 +417,7 @@ def __init__(
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=output_channel // num_attention_heads[i],
So I'm not sure what this is intended to do here. I think `attention_head_dim` would never be `None` in our case, so the default `8` would just be passed down as-is and used to calculate the number of attention heads. In any case, this prints out 40 (320 // 8); I think if we pass `num_attention_heads = 16` we would want the number of attention heads to be 16:
down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    num_attention_heads=16,
    block_out_channels=(320,),
    down_block_types=down_block_types,
    up_block_types=up_block_types)
# this prints 40
unet.down_blocks[0].attentions[0].heads
Note that by default `attention_head_dim` is defined to be `8`, and it takes priority for classes such as `AttnDownBlock2D`, as mentioned by @pcuenca in #3797.
However, this is indeed confusing, as we now have some attention blocks where `num_attention_heads` takes priority (e.g. the cross-attention blocks) and some where `attention_head_dim` takes priority. I will clean this up in a follow-up PR.
@yiyixuxu, this:
`attention_head_dim=output_channel // num_attention_heads[i],`
would break current behavior, e.g. configs that have `attention_head_dim` set to `None` would then pass an incorrect number here.
@@ -59,7 +59,7 @@ def prepare_init_args_and_inputs_for_common(self):
"block_out_channels": (32, 64),
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
"up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
"attention_head_dim": None,
"attention_head_dim": 3,
Is it possible to also test for the new argument (`num_attention_heads`) we introduced here to check for feature parity?
Yes, I'll add some tests now for `num_attention_heads`.
Agree with Pedro's observations from here. I think Pedro left a couple of nits, and I had a question on testing the new argument (`num_attention_heads`) for feature parity. Other than those, looks good to me!
* relax tolerance slightly * correct incorrect naming * correct namingc * correct more * Apply suggestions from code review * Fix more * Correct more * correct incorrect naming * Update src/diffusers/models/controlnet.py * Correct flax * Correct renaming * Correct blocks * Fix more * Correct more * mkae style * mkae style * mkae style * mkae style * mkae style * Fix flax * mkae style * rename * rename * rename attn head dim to attention_head_dim * correct flax * make style * improve * Correct more * make style * fix more * mkae style * Update src/diffusers/models/controlnet_flax.py * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Pedro Cuenca <[email protected]>
This PR corrects incorrect variable usage/naming, as discovered in #2011, in a non-breaking way.