Skip to content

Commit f73ed17

Browse files
camendurupatrickvonplatenpcuenca
authored
Allow converting Flax to PyTorch by adding a "from_flax" keyword (huggingface#1900)
* from_flax * oops * oops * make style with pip install -e ".[dev]" * oops * now code quality happy 😋 * allow_patterns += FLAX_WEIGHTS_NAME * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * for test * bye bye is_flax_available() * oops * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * make style * add test * finihs Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 9147c4c commit f73ed17

File tree

4 files changed

+343
-77
lines changed

4 files changed

+343
-77
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# coding=utf-8
2+
# Copyright 2022 The HuggingFace Inc. team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
""" PyTorch - Flax general utilities."""
16+
17+
from pickle import UnpicklingError
18+
19+
import numpy as np
20+
21+
import jax
22+
import jax.numpy as jnp
23+
from flax.serialization import from_bytes
24+
from flax.traverse_util import flatten_dict
25+
26+
from ..utils import logging
27+
28+
29+
logger = logging.get_logger(__name__)
30+
31+
32+
#####################
33+
# Flax => PyTorch #
34+
#####################
35+
36+
37+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
38+
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
    """Deserialize a Flax checkpoint file and load its weights into ``pt_model``.

    Args:
        pt_model: PyTorch model that receives the converted weights.
        model_file: Path of the serialized Flax checkpoint on disk.

    Returns:
        The result of ``load_flax_weights_in_pytorch_model`` called with ``pt_model`` and the
        deserialized Flax state.

    Raises:
        OSError: If the file cannot be deserialized. When the file starts with the text
            ``version`` the message points at a missing git-lfs checkout; otherwise a generic
            "unable to convert" message is used (``EnvironmentError`` is an alias of ``OSError``).
    """
    try:
        with open(model_file, "rb") as flax_state_f:
            flax_state = from_bytes(None, flax_state_f.read())
    except (UnpicklingError, ValueError) as e:
        # NOTE(review): flax's `from_bytes` is msgpack based, and msgpack reports trailing or
        # malformed data (e.g. a git-lfs pointer file) through ValueError subclasses such as
        # `ExtraData` rather than UnpicklingError - confirm against the installed msgpack.
        marker = "version"
        try:
            with open(model_file) as f:
                # Only the prefix matters, so avoid decoding a potentially huge file as text.
                looks_like_lfs_pointer = f.read(len(marker)) == marker
        except (UnicodeDecodeError, ValueError):
            # Binary content that is not valid text cannot be a git-lfs pointer.
            looks_like_lfs_pointer = False

        if looks_like_lfs_pointer:
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please"
                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                " folder you cloned."
            ) from e
        raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") from e

    return load_flax_weights_in_pytorch_model(pt_model, flax_state)
57+
58+
59+
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """Load an already deserialized Flax state into a PyTorch model.

    The nested Flax state is flattened to dotted keys, each key/tensor is rewritten to the
    PyTorch naming and layout used by ``pt_model.state_dict()``, and the merged state dict is
    loaded back into ``pt_model``.

    Args:
        pt_model: PyTorch model to fill. Its ``base_model_prefix`` attribute is set to ``""``
            as a side effect.
        flax_state: Nested mapping of Flax parameters (leaves must expose ``dtype``).

    Returns:
        ``pt_model`` after ``load_state_dict`` has been called on it.

    Raises:
        ImportError: If ``torch`` cannot be imported (an error is logged first).
        ValueError: If a converted Flax tensor does not have the shape of the PyTorch
            parameter it maps to.
    """

    try:
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # check if we have bf16 weights
    is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(is_type_bf16):
        # convert all weights to fp32 if they are bf16 since torch.from_numpy cannot handle bf16
        # and bf16 is not fully supported in PT yet.
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_util.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    pt_model.base_model_prefix = ""

    flax_state_dict = flatten_dict(flax_state, sep=".")
    pt_model_dict = pt_model.state_dict()

    # keep track of unexpected & missing keys
    unexpected_keys = []
    missing_keys = set(pt_model_dict.keys())

    for flax_key_tuple, flax_tensor in flax_state_dict.items():
        flax_key_tuple_array = flax_key_tuple.split(".")
        leaf_name = flax_key_tuple_array[-1]

        if leaf_name == "kernel" and flax_tensor.ndim == 4:
            # 4-d kernel: rename to "weight" and permute the axes with (3, 2, 0, 1)
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
            flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
        elif leaf_name == "kernel":
            # any other kernel: rename to "weight" and transpose
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
            flax_tensor = flax_tensor.T
        elif leaf_name == "scale":
            # "scale" is only renamed to "weight"; the tensor is left untouched
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]

        # Rewrite the "_<index>" suffixes (indices 0-3 only) as ".<index>" path components.
        # Keys that contain a "time_embedding" component keep their names unchanged.
        if "time_embedding" not in flax_key_tuple_array:
            for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
                flax_key_tuple_array[i] = (
                    flax_key_tuple_string.replace("_0", ".0")
                    .replace("_1", ".1")
                    .replace("_2", ".2")
                    .replace("_3", ".3")
                )

        flax_key = ".".join(flax_key_tuple_array)

        if flax_key not in pt_model_dict:
            # weight is not expected by PyTorch model
            unexpected_keys.append(flax_key)
            continue

        if flax_tensor.shape != pt_model_dict[flax_key].shape:
            raise ValueError(
                f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
                f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
            )

        # add weight to pytorch dict
        flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
        pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
        # remove from missing keys; `discard` tolerates two Flax entries resolving to the same
        # PyTorch key instead of failing with a bare KeyError
        missing_keys.discard(flax_key)

    pt_model.load_state_dict(pt_model_dict)

    # re-transform missing_keys to list
    missing_keys = list(missing_keys)

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )

    return pt_model

src/diffusers/models/modeling_utils.py

Lines changed: 117 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ..utils import (
3131
CONFIG_NAME,
3232
DIFFUSERS_CACHE,
33+
FLAX_WEIGHTS_NAME,
3334
HF_HUB_OFFLINE,
3435
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
3536
SAFETENSORS_WEIGHTS_NAME,
@@ -335,6 +336,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
335336
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
336337
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
337338
identifier allowed by git.
339+
from_flax (`bool`, *optional*, defaults to `False`):
340+
Load the model weights from a Flax checkpoint save file.
338341
subfolder (`str`, *optional*, defaults to `""`):
339342
In case the relevant files are located inside a subfolder of the model repo (either remote in
340343
huggingface.co or downloaded locally), you can specify the folder name here.
@@ -375,6 +378,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
375378
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
376379
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
377380
force_download = kwargs.pop("force_download", False)
381+
from_flax = kwargs.pop("from_flax", False)
378382
resume_download = kwargs.pop("resume_download", False)
379383
proxies = kwargs.pop("proxies", None)
380384
output_loading_info = kwargs.pop("output_loading_info", False)
@@ -433,27 +437,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
433437
# Load model
434438

435439
model_file = None
436-
if is_safetensors_available():
437-
try:
438-
model_file = cls._get_model_file(
439-
pretrained_model_name_or_path,
440-
weights_name=SAFETENSORS_WEIGHTS_NAME,
441-
cache_dir=cache_dir,
442-
force_download=force_download,
443-
resume_download=resume_download,
444-
proxies=proxies,
445-
local_files_only=local_files_only,
446-
use_auth_token=use_auth_token,
447-
revision=revision,
448-
subfolder=subfolder,
449-
user_agent=user_agent,
450-
)
451-
except:
452-
pass
453-
if model_file is None:
440+
if from_flax:
454441
model_file = cls._get_model_file(
455442
pretrained_model_name_or_path,
456-
weights_name=WEIGHTS_NAME,
443+
weights_name=FLAX_WEIGHTS_NAME,
457444
cache_dir=cache_dir,
458445
force_download=force_download,
459446
resume_download=resume_download,
@@ -464,10 +451,105 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
464451
subfolder=subfolder,
465452
user_agent=user_agent,
466453
)
454+
config, unused_kwargs = cls.load_config(
455+
config_path,
456+
cache_dir=cache_dir,
457+
return_unused_kwargs=True,
458+
force_download=force_download,
459+
resume_download=resume_download,
460+
proxies=proxies,
461+
local_files_only=local_files_only,
462+
use_auth_token=use_auth_token,
463+
revision=revision,
464+
subfolder=subfolder,
465+
device_map=device_map,
466+
**kwargs,
467+
)
468+
model = cls.from_config(config, **unused_kwargs)
469+
470+
# Convert the weights
471+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
467472

468-
if low_cpu_mem_usage:
469-
# Instantiate model with empty weights
470-
with accelerate.init_empty_weights():
473+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
474+
else:
475+
if is_safetensors_available():
476+
try:
477+
model_file = cls._get_model_file(
478+
pretrained_model_name_or_path,
479+
weights_name=SAFETENSORS_WEIGHTS_NAME,
480+
cache_dir=cache_dir,
481+
force_download=force_download,
482+
resume_download=resume_download,
483+
proxies=proxies,
484+
local_files_only=local_files_only,
485+
use_auth_token=use_auth_token,
486+
revision=revision,
487+
subfolder=subfolder,
488+
user_agent=user_agent,
489+
)
490+
except:
491+
pass
492+
if model_file is None:
493+
model_file = cls._get_model_file(
494+
pretrained_model_name_or_path,
495+
weights_name=WEIGHTS_NAME,
496+
cache_dir=cache_dir,
497+
force_download=force_download,
498+
resume_download=resume_download,
499+
proxies=proxies,
500+
local_files_only=local_files_only,
501+
use_auth_token=use_auth_token,
502+
revision=revision,
503+
subfolder=subfolder,
504+
user_agent=user_agent,
505+
)
506+
507+
if low_cpu_mem_usage:
508+
# Instantiate model with empty weights
509+
with accelerate.init_empty_weights():
510+
config, unused_kwargs = cls.load_config(
511+
config_path,
512+
cache_dir=cache_dir,
513+
return_unused_kwargs=True,
514+
force_download=force_download,
515+
resume_download=resume_download,
516+
proxies=proxies,
517+
local_files_only=local_files_only,
518+
use_auth_token=use_auth_token,
519+
revision=revision,
520+
subfolder=subfolder,
521+
device_map=device_map,
522+
**kwargs,
523+
)
524+
model = cls.from_config(config, **unused_kwargs)
525+
526+
# if device_map is None, load the state dict and move the params from meta device to the cpu
527+
if device_map is None:
528+
param_device = "cpu"
529+
state_dict = load_state_dict(model_file)
530+
# move the params from meta device to cpu
531+
for param_name, param in state_dict.items():
532+
accepts_dtype = "dtype" in set(
533+
inspect.signature(set_module_tensor_to_device).parameters.keys()
534+
)
535+
if accepts_dtype:
536+
set_module_tensor_to_device(
537+
model, param_name, param_device, value=param, dtype=torch_dtype
538+
)
539+
else:
540+
set_module_tensor_to_device(model, param_name, param_device, value=param)
541+
else: # else let accelerate handle loading and dispatching.
542+
# Load weights and dispatch according to the device_map
543+
# by default the device_map is None and the weights are loaded on the CPU
544+
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
545+
546+
loading_info = {
547+
"missing_keys": [],
548+
"unexpected_keys": [],
549+
"mismatched_keys": [],
550+
"error_msgs": [],
551+
}
552+
else:
471553
config, unused_kwargs = cls.load_config(
472554
config_path,
473555
cache_dir=cache_dir,
@@ -484,61 +566,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
484566
)
485567
model = cls.from_config(config, **unused_kwargs)
486568

487-
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
488-
if device_map is None:
489-
param_device = "cpu"
490569
state_dict = load_state_dict(model_file)
491-
# move the parms from meta device to cpu
492-
for param_name, param in state_dict.items():
493-
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
494-
if accepts_dtype:
495-
set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
496-
else:
497-
set_module_tensor_to_device(model, param_name, param_device, value=param)
498-
else: # else let accelerate handle loading and dispatching.
499-
# Load weights and dispatch according to the device_map
500-
# by deafult the device_map is None and the weights are loaded on the CPU
501-
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
502-
503-
loading_info = {
504-
"missing_keys": [],
505-
"unexpected_keys": [],
506-
"mismatched_keys": [],
507-
"error_msgs": [],
508-
}
509-
else:
510-
config, unused_kwargs = cls.load_config(
511-
config_path,
512-
cache_dir=cache_dir,
513-
return_unused_kwargs=True,
514-
force_download=force_download,
515-
resume_download=resume_download,
516-
proxies=proxies,
517-
local_files_only=local_files_only,
518-
use_auth_token=use_auth_token,
519-
revision=revision,
520-
subfolder=subfolder,
521-
device_map=device_map,
522-
**kwargs,
523-
)
524-
model = cls.from_config(config, **unused_kwargs)
525570

526-
state_dict = load_state_dict(model_file)
527-
528-
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
529-
model,
530-
state_dict,
531-
model_file,
532-
pretrained_model_name_or_path,
533-
ignore_mismatched_sizes=ignore_mismatched_sizes,
534-
)
571+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
572+
model,
573+
state_dict,
574+
model_file,
575+
pretrained_model_name_or_path,
576+
ignore_mismatched_sizes=ignore_mismatched_sizes,
577+
)
535578

536-
loading_info = {
537-
"missing_keys": missing_keys,
538-
"unexpected_keys": unexpected_keys,
539-
"mismatched_keys": mismatched_keys,
540-
"error_msgs": error_msgs,
541-
}
579+
loading_info = {
580+
"missing_keys": missing_keys,
581+
"unexpected_keys": unexpected_keys,
582+
"mismatched_keys": mismatched_keys,
583+
"error_msgs": error_msgs,
584+
}
542585

543586
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
544587
raise ValueError(

0 commit comments

Comments
 (0)