Commit 59aefe9

device map legacy attention block weight conversion (huggingface#3804)

1 parent 3ddc2b7 · commit 59aefe9
File tree: 3 files changed (+137, -10 lines)

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 0 deletions

@@ -78,6 +78,7 @@ def __init__(
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
+        self.dropout = dropout
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
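Storing the `dropout` probability on the module is what lets the temporary conversion in `modeling_utils.py` (further down in this diff) rebuild `to_out` as `nn.ModuleList([proj_attn, nn.Dropout(module.dropout)])` when it undoes the rename. As orientation for the rest of the diff, here is a small illustrative sketch of how the deprecated attention weight names relate to the current `Attention` attributes; `rename_deprecated_keys` is a hypothetical helper, not part of this PR:

```python
# Illustrative only: how the deprecated attention attribute names map onto the
# current Attention attributes (the real conversion is the code added below).
DEPRECATED_TO_NEW = {
    "query": "to_q",          # nn.Linear
    "key": "to_k",            # nn.Linear
    "value": "to_v",          # nn.Linear
    "proj_attn": "to_out.0",  # first entry of nn.ModuleList([Linear, Dropout(dropout)])
}


def rename_deprecated_keys(state_dict):
    """Hypothetical helper: rewrite legacy checkpoint keys to the new names."""
    renamed = {}
    for key, tensor in state_dict.items():
        for old, new in DEPRECATED_TO_NEW.items():
            key = key.replace(f".{old}.", f".{new}.")
        renamed[key] = tensor
    return renamed
```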

src/diffusers/models/modeling_utils.py

Lines changed: 92 additions & 10 deletions

@@ -22,7 +22,7 @@
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn
 
 from .. import __version__
 from ..utils import (
@@ -646,15 +646,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(
-                    model,
-                    model_file,
-                    device_map,
-                    max_memory=max_memory,
-                    offload_folder=offload_folder,
-                    offload_state_dict=offload_state_dict,
-                    dtype=torch_dtype,
-                )
+                try:
+                    accelerate.load_checkpoint_and_dispatch(
+                        model,
+                        model_file,
+                        device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        dtype=torch_dtype,
+                    )
+                except AttributeError as e:
+                    # When using accelerate loading, we do not have the ability to load the state
+                    # dict and rename the weight names manually. Additionally, accelerate skips
+                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+                    # (which look like they should be private variables?), so we can't use the standard hooks
+                    # to rename parameters on load. We need to mimic the original weight names so the correct
+                    # attributes are available. After we have loaded the weights, we convert the deprecated
+                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
+                    # the weights so we don't have to do this again.
+
+                    if "'Attention' object has no attribute" in str(e):
+                        logger.warn(
+                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
+                            " please also re-upload it or open a PR on the original repository."
+                        )
+                        model._temp_convert_self_to_deprecated_attention_blocks()
+                        accelerate.load_checkpoint_and_dispatch(
+                            model,
+                            model_file,
+                            device_map,
+                            max_memory=max_memory,
+                            offload_folder=offload_folder,
+                            offload_state_dict=offload_state_dict,
+                            dtype=torch_dtype,
+                        )
+                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                    else:
+                        raise e
 
             loading_info = {
                 "missing_keys": [],
@@ -889,3 +921,53 @@ def recursive_find_attn_block(name, module):
                 state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+    def _temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.query = module.to_q
+            module.key = module.to_k
+            module.value = module.to_v
+            module.proj_attn = module.to_out[0]
+
+            # We don't _have_ to delete the old attributes, but it's helpful to ensure
+            # that _all_ the weights are loaded into the new attributes and we're not
+            # making an incorrect assumption that this model should be converted when
+            # it really shouldn't be.
+            del module.to_q
+            del module.to_k
+            del module.to_v
+            del module.to_out
+
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.to_q = module.query
+            module.to_k = module.key
+            module.to_v = module.value
+            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+            del module.query
+            del module.key
+            del module.value
+            del module.proj_attn
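The warning emitted above asks users to re-save converted checkpoints so the on-the-fly renaming is a one-time cost. A minimal sketch of that workflow, mirroring the new test below; `"your-namespace/legacy-checkpoint"` is a placeholder, not a real repository:

```python
from diffusers import DiffusionPipeline

# Loading with a device_map exercises the try/except fallback above and converts
# the deprecated attention weights in memory.
pipe = DiffusionPipeline.from_pretrained(
    "your-namespace/legacy-checkpoint", device_map="sequential", safety_checker=None
)

# Saving writes the weights under the new to_q/to_k/to_v/to_out.0 names,
# so future loads skip the fallback entirely.
pipe.save_pretrained("./legacy-checkpoint-converted")
```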

tests/models/test_attention_processor.py

Lines changed: 44 additions & 0 deletions

@@ -1,7 +1,10 @@
+import tempfile
 import unittest
 
+import numpy as np
 import torch
 
+from diffusers import DiffusionPipeline
 from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
 
 
@@ -73,3 +76,44 @@ def test_only_cross_attention(self):
         only_cross_attn_out = attn(**forward_args)
 
         self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
+
+
+class DeprecatedAttentionBlockTests(unittest.TestCase):
+    def test_conversion_when_using_device_map(self):
+        pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)
+
+        pre_conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        # the initial conversion succeeds
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
+        )
+
+        conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # save the converted model
+            pipe.save_pretrained(tmpdir)
+
+            # can also load the converted weights
+            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)
+
+            after_conversion = pipe(
+                "foo",
+                num_inference_steps=2,
+                generator=torch.Generator("cpu").manual_seed(0),
+                output_type="np",
+            ).images
+
+        self.assertTrue(np.allclose(pre_conversion, conversion))
+        self.assertTrue(np.allclose(conversion, after_conversion))
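The two `np.allclose` checks pin down that the conversion is numerically a no-op: the plain CPU load, the `device_map="sequential"` load that exercises the new fallback path, and the reload of the re-saved (already converted) weights must all produce the same images for the same seed.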
