
Correct bad attn naming #3797


Merged: 41 commits from correct_bad_attn_naming into main on Jun 22, 2023

Conversation

@patrickvonplaten (Contributor) commented Jun 15, 2023

This PR corrects incorrect variable usage/naming, as discovered in #2011, in a non-breaking way.

@HuggingFaceDocBuilderDev commented Jun 15, 2023

The documentation is not available anymore as the PR was closed or merged.

# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000 configurations is too backwards breaking
# which is why we correct for the naming here.
patrickvonplaten (author):

admitting naming mess-up 😅
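
For readers following the thread: the backwards-compatible correction that the quoted comment describes boils down to deriving the new argument from the legacy one. A minimal sketch of the idea (an illustration, not a verbatim excerpt from the PR):

# Sketch: in legacy configs, `attention_head_dim` actually stored the number of heads,
# so when `num_attention_heads` is not given we fall back to the misnamed value.
def resolve_num_attention_heads(num_attention_heads=None, attention_head_dim=8):
    if num_attention_heads is None:
        return attention_head_dim  # legacy path: this value is really the head count
    return num_attention_heads

print(resolve_num_attention_heads(attention_head_dim=8))    # -> 8, old behavior preserved
print(resolve_num_attention_heads(num_attention_heads=16))  # -> 16, new explicit argument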

@@ -35,7 +35,7 @@ def get_down_block(
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
patrickvonplaten (author):

Renaming is fine here, since none of these classes are public.

sayakpaul (Member) left a comment:

Good that you identified this still relatively early.

If you can, please share tips for introducing these kinds of changes efficiently.

heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels,
heads=in_channels // attn_num_heads if attn_num_heads is not None else 1,
dim_head=attn_num_heads if attn_num_heads is not None else in_channels,
yiyixuxu (Collaborator):

my head is 🤯 a little bit now reading this PR, so I could be wrong,

but from what I understand, attn_num_head_channels has been passed correctly as dim_head for all attention classes except Transformer2D, so this change will cause unexpected behavior when we use the new num_attention_heads argument from UNet2DConditionModel with the other attention classes.

Taking AttnDownBlock2D as an example here:

Using the attention_head_dim argument returns the expected result:

from diffusers import UNet2DConditionModel
down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    attention_head_dim = 16,
    block_out_channels = (320,),
    down_block_types = down_block_types,
    up_block_types = up_block_types)
# this prints 20
unet.down_blocks[0].attentions[0].heads

Using the num_attention_heads argument returns the wrong result:

from diffusers import UNet2DConditionModel
down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    num_attention_heads = 16,
    block_out_channels = (320,),
    down_block_types = down_block_types,
    up_block_types = up_block_types)
# this prints 20, which is wrong? should be 16
unet.down_blocks[0].attentions[0].heads

maybe we need to pass both num_attention_heads and attention_head_dim to the blocks?

patrickvonplaten (author):

You're 100% right, great catch! I think there was a double naming mistake correcting itself here, haha.

sayakpaul (Member):

Shouldn't I expect a consistent mapping between what the user specifies as num_attention_heads or attention_head_dim and how it's handled internally?

To put this in perspective:

from diffusers import UNet2DConditionModel

down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    num_attention_heads = 16,
    block_out_channels = (320,),
    down_block_types = down_block_types,
    up_block_types = up_block_types
)

# this prints 40
unet.down_blocks[0].attentions[0].heads

and

from diffusers import UNet2DConditionModel
down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    attention_head_dim = 16,
    block_out_channels = (320,),
    down_block_types = down_block_types,
    up_block_types = up_block_types)
# this prints 20
unet.down_blocks[0].attentions[0].heads

Are all these expected? Shouldn't unet.down_blocks[0].attentions[0].heads print 16, unless I am missing something?

sayakpaul (Member):

Just so you don't miss it: pinging @patrickvonplaten (apologies in advance).

pcuenca (Member):

@sayakpaul num_attention_heads is not used for AttnDownBlock2D, only for CrossAttnDownBlock2D. If you run your first test with that type of block, num_attention_heads would be 16 as expected. But yes, I understand it can be confusing. Not sure how we can deal with it, perhaps we can follow up in a new PR?
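
To make that concrete, here is a variant of the snippet above using cross-attention block types (an untested sketch: it assumes a diffusers version that accepts num_attention_heads on UNet2DConditionModel, as in the snippets in this thread, and the attribute path to the attention module is an assumption):

from diffusers import UNet2DConditionModel

# Sketch: with cross-attention blocks, num_attention_heads is the value that is honored.
unet = UNet2DConditionModel(
    num_attention_heads=16,
    block_out_channels=(320,),
    down_block_types=("CrossAttnDownBlock2D",),
    up_block_types=("CrossAttnUpBlock2D",),
)
# attribute path assumed: first transformer block's self-attention
# expected to print 16 here, unlike the AttnDownBlock2D example above
print(unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1.heads)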

Member:

Hmm, I don't immediately have an idea for how to mitigate that faithfully, though.

yiyixuxu (Collaborator):

@pcuenca I thought both AttnDownBlock2D and CrossAttnDownBlock2D need this argument, no? Anything with the Attention class does.

sayakpaul (Member) left a comment:

Did another pass and left some comments.

How can we best test the changes to ensure robustness here?

pcuenca (Member) left a comment:

Interesting PR and issue! I asked for clarification on a couple of details :)

@@ -219,7 +228,7 @@ def __init__(
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
Member:

Shouldn't we pass attn_head_dim = attention_head_dim here too? We are ignoring it (replaced with num_attention_heads), but then get_down_block will complain that it's recommended to pass attn_head_dim and will default to copying it from num_attention_heads. We have all the information at this point in the caller.

Member:

Also, let's make the naming consistent when we can (attn_head_dim vs. attention_head_dim).

patrickvonplaten (author):

Yes good point!

):
# If attn head dim is not defined, we default it to the number of heads
if attn_head_dim is None:
Member:

Suggested change
if attn_head_dim is None:
if attn_head_dim is None and num_attention_heads is not None:

When we call with None (i.e., in the vae), there's no point in showing the warning imo.

patrickvonplaten (author):

We should not pass None to get_up_block and get_down_block anymore, IMO. I'm correcting the VAE here.

@@ -221,7 +233,8 @@ def __init__(
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
patrickvonplaten (author) commented Jun 21, 2023:

Let's make sure we never pass attention_head_dim[i] = None to the get_up_block / get_down_block functions. This reduces the black magic in the block code and makes it easier for the reader to understand how things are defined for SD.
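
A minimal sketch of that caller-side resolution (mirroring the diff above; an illustration, not the exact code that landed):

# Illustration: resolve the fallback before calling get_down_block / get_up_block,
# so the block factories never receive None and never have to guess a default.
def resolve_attention_head_dim(attention_head_dim, output_channel):
    # configs that left attention_head_dim unset fall back to the block's channel count
    return attention_head_dim if attention_head_dim is not None else output_channel

print(resolve_attention_head_dim(None, 320))  # -> 320
print(resolve_attention_head_dim(8, 320))     # -> 8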

pcuenca (Member) left a comment:

Awesome! I think the latest changes are easier and require fewer warnings.

@@ -398,6 +417,7 @@ def __init__(
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
yiyixuxu (Collaborator):

Suggested change
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=output_channel //num_attention_heads[i],

So I'm not sure what this is intended to do here. I think attention_head_dim would never be None in our case, so the default of 8 would just be passed down as-is and used to calculate the number of attention heads.

Anyway, this prints out 40 (320 // 8). I think if we pass num_attention_heads = 16, we would want the number of attention heads to be 16:

from diffusers import UNet2DConditionModel
down_block_types = ("AttnDownBlock2D",)
up_block_types = ("AttnUpBlock2D",)
unet = UNet2DConditionModel(
    num_attention_heads = 16,
    block_out_channels = (320,),
    down_block_types = down_block_types,
    up_block_types = up_block_types)
# this prints 40
unet.down_blocks[0].attentions[0].heads

patrickvonplaten (author):

Note that by default attention_head_dim is defined to be 8, and it takes priority for classes such as AttnDownBlock2D, as mentioned by @pcuenca in #3797.

However, this is indeed confusing, as we now have some attention blocks where num_attention_heads takes priority (e.g. the cross-attention blocks) and some where attention_head_dim takes priority. I will clean this up in a follow-up PR.

patrickvonplaten (author):

@yiyixuxu , this:

attention_head_dim=output_channel //num_attention_heads[i],

would break current behavior, e.g. for configs that have attention_head_dim set to None, which would then pass an incorrect number here.

@@ -59,7 +59,7 @@ def prepare_init_args_and_inputs_for_common(self):
"block_out_channels": (32, 64),
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
"up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
"attention_head_dim": None,
"attention_head_dim": 3,
sayakpaul (Member):

Is it possible to also test for the new argument (num_attention_heads) we introduced here to check for feature parity?

patrickvonplaten (author):

Yes, I'll add some tests now for num_attention_heads
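
For illustration, such a parity check could look roughly like the following (a hypothetical sketch with made-up configuration values and an assumed attribute path, not the test actually added in this PR):

from diffusers import UNet2DConditionModel

def check_num_attention_heads():
    # Hypothetical check: num_attention_heads should control the head count
    # of the cross-attention blocks, mirroring the old attention_head_dim behavior.
    unet = UNet2DConditionModel(
        num_attention_heads=(2, 4),
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
    )
    attn = unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1  # path assumed
    assert attn.heads == 2, f"expected 2 heads, got {attn.heads}"

check_num_attention_heads()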

sayakpaul (Member) left a comment:

Agree with Pedro's observations from here.

I think Pedro left a couple of nits and I had a question on testing the new argument (num_attention_heads) for feature parity. Other than those, looks good to me!

patrickvonplaten merged commit 88d2694 into main on Jun 22, 2023
patrickvonplaten deleted the correct_bad_attn_naming branch on June 22, 2023 at 11:53
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* relax tolerance slightly

* correct incorrect naming

* correct namingc

* correct more

* Apply suggestions from code review

* Fix more

* Correct more

* correct incorrect naming

* Update src/diffusers/models/controlnet.py

* Correct flax

* Correct renaming

* Correct blocks

* Fix more

* Correct more

* mkae style

* mkae style

* mkae style

* mkae style

* mkae style

* Fix flax

* mkae style

* rename

* rename

* rename attn head dim to attention_head_dim

* correct flax

* make style

* improve

* Correct more

* make style

* fix more

* mkae style

* Update src/diffusers/models/controlnet_flax.py

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

---------

Co-authored-by: Pedro Cuenca <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
(same commit message as above)