Skip to content

Commit 20ce68f

Browse files
Fix dtype model loading (huggingface#1449)
* Add test
* up
* no bfloat16 for mps
* fix
* rename test
1 parent 110ffe2 commit 20ce68f

File tree

7 files changed: +47 / -14 lines changed

7 files changed: +47 / -14 lines changed

src/diffusers/modeling_utils.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -472,6 +472,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
472472
model = cls.from_config(config, **unused_kwargs)
473473

474474
state_dict = load_state_dict(model_file)
475+
dtype = set(v.dtype for v in state_dict.values())
476+
477+
if len(dtype) > 1 and torch.float32 not in dtype:
478+
raise ValueError(
479+
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
480+
f" make sure that {model_file} weights have only one dtype."
481+
)
482+
elif len(dtype) > 1 and torch.float32 in dtype:
483+
dtype = torch.float32
484+
else:
485+
dtype = dtype.pop()
486+
487+
# move model to correct dtype
488+
model = model.to(dtype)
489+
475490
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
476491
model,
477492
state_dict,

tests/models/test_models_unet_1d.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -63,8 +63,8 @@ def test_outputs_equivalence(self):
6363
super().test_outputs_equivalence()
6464

6565
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
66-
def test_from_pretrained_save_pretrained(self):
67-
super().test_from_pretrained_save_pretrained()
66+
def test_from_save_pretrained(self):
67+
super().test_from_save_pretrained()
6868

6969
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
7070
def test_model_from_pretrained(self):
@@ -183,8 +183,8 @@ def test_outputs_equivalence(self):
183183
super().test_outputs_equivalence()
184184

185185
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
186-
def test_from_pretrained_save_pretrained(self):
187-
super().test_from_pretrained_save_pretrained()
186+
def test_from_save_pretrained(self):
187+
super().test_from_save_pretrained()
188188

189189
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
190190
def test_model_from_pretrained(self):

tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def tearDown(self):
4242
gc.collect()
4343
torch.cuda.empty_cache()
4444

45-
def test_from_pretrained_save_pretrained(self):
45+
def test_from_save_pretrained(self):
4646
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
4747
pipe.to(torch_device)
4848
pipe.set_progress_bar_config(disable=None)

tests/test_modeling_common.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727

2828

2929
class ModelTesterMixin:
30-
def test_from_pretrained_save_pretrained(self):
30+
def test_from_save_pretrained(self):
3131
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
3232

3333
model = self.model_class(**init_dict)
@@ -57,6 +57,24 @@ def test_from_pretrained_save_pretrained(self):
5757
max_diff = (image - new_image).abs().sum().item()
5858
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
5959

60+
def test_from_save_pretrained_dtype(self):
61+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
62+
63+
model = self.model_class(**init_dict)
64+
model.to(torch_device)
65+
model.eval()
66+
67+
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
68+
if torch_device == "mps" and dtype == torch.bfloat16:
69+
continue
70+
with tempfile.TemporaryDirectory() as tmpdirname:
71+
model.to(dtype)
72+
model.save_pretrained(tmpdirname)
73+
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True)
74+
assert new_model.dtype == dtype
75+
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False)
76+
assert new_model.dtype == dtype
77+
6078
def test_determinism(self):
6179
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
6280
model = self.model_class(**init_dict)

tests/test_pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -659,7 +659,7 @@ def test_warning_unused_kwargs(self):
659659
== "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n"
660660
)
661661

662-
def test_from_pretrained_save_pretrained(self):
662+
def test_from_save_pretrained(self):
663663
# 1. Load models
664664
model = UNet2DModel(
665665
block_out_channels=(32, 64),

tests/test_scheduler.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -334,7 +334,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
334334

335335
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
336336

337-
def test_from_pretrained_save_pretrained(self):
337+
def test_from_save_pretrained(self):
338338
kwargs = dict(self.forward_default_kwargs)
339339

340340
num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -875,7 +875,7 @@ def check_over_configs(self, time_step=0, **config):
875875

876876
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
877877

878-
def test_from_pretrained_save_pretrained(self):
878+
def test_from_save_pretrained(self):
879879
pass
880880

881881
def check_over_forward(self, time_step=0, **forward_kwargs):
@@ -1068,7 +1068,7 @@ def check_over_configs(self, time_step=0, **config):
10681068

10691069
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
10701070

1071-
def test_from_pretrained_save_pretrained(self):
1071+
def test_from_save_pretrained(self):
10721072
pass
10731073

10741074
def check_over_forward(self, time_step=0, **forward_kwargs):
@@ -1745,7 +1745,7 @@ def check_over_configs(self, time_step=0, **config):
17451745

17461746
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
17471747

1748-
def test_from_pretrained_save_pretrained(self):
1748+
def test_from_save_pretrained(self):
17491749
pass
17501750

17511751
def check_over_forward(self, time_step=0, **forward_kwargs):

tests/test_scheduler_flax.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
126126

127127
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
128128

129-
def test_from_pretrained_save_pretrained(self):
129+
def test_from_save_pretrained(self):
130130
kwargs = dict(self.forward_default_kwargs)
131131

132132
num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -408,7 +408,7 @@ def check_over_configs(self, time_step=0, **config):
408408

409409
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
410410

411-
def test_from_pretrained_save_pretrained(self):
411+
def test_from_save_pretrained(self):
412412
kwargs = dict(self.forward_default_kwargs)
413413

414414
num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -690,7 +690,7 @@ def check_over_configs(self, time_step=0, **config):
690690

691691
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
692692

693-
def test_from_pretrained_save_pretrained(self):
693+
def test_from_save_pretrained(self):
694694
pass
695695

696696
def test_scheduler_outputs_equivalence(self):

0 commit comments

Comments
 (0)