|
@@ -22,7 +22,7 @@
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn
 
 from .. import __version__
 from ..utils import (
@@ -646,15 +646,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(
-                    model,
-                    model_file,
-                    device_map,
-                    max_memory=max_memory,
-                    offload_folder=offload_folder,
-                    offload_state_dict=offload_state_dict,
-                    dtype=torch_dtype,
-                )
+                try:
+                    accelerate.load_checkpoint_and_dispatch(
+                        model,
+                        model_file,
+                        device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        dtype=torch_dtype,
+                    )
+                except AttributeError as e:
+                    # When using accelerate loading, we do not have the ability to load the state
+                    # dict and rename the weight names manually. Additionally, accelerate skips
+                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+                    # (which look like they should be private variables?), so we can't use the standard hooks
+                    # to rename parameters on load. We need to mimic the original weight names so the correct
+                    # attributes are available. After we have loaded the weights, we convert the deprecated
+                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
+                    # the weights so we don't have to do this again.
+
+                    if "'Attention' object has no attribute" in str(e):
+                        logger.warning(
+                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                            " so we don't have to do the on-the-fly renaming in the future. If the model is from a hub checkpoint,"
+                            " please also re-upload it or open a PR on the original repository."
+                        )
+                        model._temp_convert_self_to_deprecated_attention_blocks()
+                        accelerate.load_checkpoint_and_dispatch(
+                            model,
+                            model_file,
+                            device_map,
+                            max_memory=max_memory,
+                            offload_folder=offload_folder,
+                            offload_state_dict=offload_state_dict,
+                            dtype=torch_dtype,
+                        )
+                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                    else:
+                        raise e
 
         loading_info = {
             "missing_keys": [],
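
For background on the comment in the `except` branch above: with the plain `load_state_dict` path, deprecated weight names can be remapped through a load-state-dict pre-hook, but `accelerate.load_checkpoint_and_dispatch` writes tensors directly onto module attributes and never runs those hooks, which is why the fallback temporarily recreates the old attribute names instead. A minimal sketch of the hook mechanism that accelerate bypasses (the `AttnShim` module and its key mapping are illustrative, not diffusers code):

```python
import torch
from torch import nn


class AttnShim(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(8, 8)
        # The pre-hook runs before keys are matched against the module, so
        # deprecated names can be rewritten in the incoming state dict.
        # Accelerate's dispatch path never calls `load_state_dict`, so this
        # hook never fires there.
        self._register_load_state_dict_pre_hook(self._rename_deprecated_keys)

    def _rename_deprecated_keys(self, state_dict, prefix, *args):
        for suffix in ("weight", "bias"):
            old_key = f"{prefix}query.{suffix}"
            if old_key in state_dict:
                state_dict[f"{prefix}to_q.{suffix}"] = state_dict.pop(old_key)


shim = AttnShim()
legacy = {"query.weight": torch.randn(8, 8), "query.bias": torch.randn(8)}
shim.load_state_dict(legacy)  # succeeds: the hook renamed `query` to `to_q`
```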
@@ -889,3 +921,53 @@ def recursive_find_attn_block(name, module):
                 state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+    def _temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.query = module.to_q
+            module.key = module.to_k
+            module.value = module.to_v
+            module.proj_attn = module.to_out[0]
+
+            # We don't _have_ to delete the old attributes, but it's helpful to ensure
+            # that _all_ the weights are loaded into the new attributes and we're not
+            # making an incorrect assumption that this model should be converted when
+            # it really shouldn't be.
+            del module.to_q
+            del module.to_k
+            del module.to_v
+            del module.to_out
+
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.to_q = module.query
+            module.to_k = module.key
+            module.to_v = module.value
+            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+            del module.query
+            del module.key
+            del module.value
+            del module.proj_attn
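
The warning above asks users to re-save the model so the fallback only ever runs once per checkpoint. A usage sketch of that round trip, under assumptions: the paths are hypothetical, `AutoencoderKL` stands in for any model containing the deprecated attention block, and `low_cpu_mem_usage=True` with a `device_map` is what routes loading through `accelerate.load_checkpoint_and_dispatch`:

```python
from diffusers import AutoencoderKL

# Hypothetical checkpoint still saved with `query`/`key`/`value`/`proj_attn`
# weight names; loading it through accelerate triggers the fallback above.
vae = AutoencoderKL.from_pretrained(
    "path/to/old-checkpoint",
    low_cpu_mem_usage=True,
    device_map="auto",
)

# Re-saving writes the new `to_q`/`to_k`/`to_v`/`to_out` names, so future
# loads skip the on-the-fly renaming (and the warning) entirely.
vae.save_pretrained("path/to/converted-checkpoint")
```

Note the symmetry in the undo step: `to_out[0]` is the output projection and index 1 is its dropout, which is why `_undo_temp_convert_self_to_deprecated_attention_blocks` rebuilds `to_out` as `nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])` rather than restoring a saved reference.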