Skip to content

Commit 8267c78

Browse files
[Loading] Better error message on missing keys (huggingface#2198)
* up * finish
1 parent 4fc7084 commit 8267c78

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
541541
param_device = "cpu"
542542
state_dict = load_state_dict(model_file)
543543
# move the params from meta device to cpu
544+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
545+
if len(missing_keys) > 0:
546+
raise ValueError(
547+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
548+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
549+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
550+
" those weights or else make sure your checkpoint file is correct."
551+
)
552+
544553
for param_name, param in state_dict.items():
545554
accepts_dtype = "dtype" in set(
546555
inspect.signature(set_module_tensor_to_device).parameters.keys()

tests/test_modeling_common.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,20 @@
2121
import numpy as np
2222
import torch
2323

24-
from diffusers.models import ModelMixin
24+
from diffusers.models import ModelMixin, UNet2DConditionModel
2525
from diffusers.training_utils import EMAModel
2626
from diffusers.utils import torch_device
2727

2828

29+
class ModelUtilsTest(unittest.TestCase):
30+
def test_accelerate_loading_error_message(self):
31+
with self.assertRaises(ValueError) as error_context:
32+
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
33+
34+
# make sure that error message states what keys are missing
35+
assert "conv_out.bias" in str(error_context.exception)
36+
37+
2938
class ModelTesterMixin:
3039
def test_from_save_pretrained(self):
3140
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

0 commit comments

Comments
 (0)