Skip to content

Commit db5fa43

Browse files
[Loading] allow modules to be loaded in fp16 (huggingface#230)
1 parent 0ab9485 commit db5fa43

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/diffusers/modeling_utils.py

Lines changed: 7 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -315,6 +315,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
315315
use_auth_token = kwargs.pop("use_auth_token", None)
316316
revision = kwargs.pop("revision", None)
317317
from_auto_class = kwargs.pop("_from_auto", False)
318+
torch_dtype = kwargs.pop("torch_dtype", None)
318319
subfolder = kwargs.pop("subfolder", None)
319320

320321
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -334,6 +335,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
334335
subfolder=subfolder,
335336
**kwargs,
336337
)
338+
339+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
340+
raise ValueError(f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}.")
341+
elif torch_dtype is not None:
342+
model = model.to(torch_dtype)
343+
337344
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
338345
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
339346
# Load model

src/diffusers/pipeline_utils.py

Lines changed: 7 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -146,6 +146,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
146146
local_files_only = kwargs.pop("local_files_only", False)
147147
use_auth_token = kwargs.pop("use_auth_token", None)
148148
revision = kwargs.pop("revision", None)
149+
torch_dtype = kwargs.pop("torch_dtype", None)
149150

150151
# 1. Download the checkpoints and configs
151152
# use snapshot download here to get it working from from_pretrained
@@ -237,12 +238,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
237238

238239
load_method = getattr(class_obj, load_method_name)
239240

241+
loading_kwargs = {}
242+
if issubclass(class_obj, torch.nn.Module):
243+
loading_kwargs["torch_dtype"] = torch_dtype
244+
240245
# check if the module is in a subdirectory
241246
if os.path.isdir(os.path.join(cached_folder, name)):
242-
loaded_sub_model = load_method(os.path.join(cached_folder, name))
247+
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
243248
else:
244249
# else load from the root directory
245-
loaded_sub_model = load_method(cached_folder)
250+
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
246251

247252
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
248253

0 commit comments

Comments (0)