     AutoFeatureExtractor,
     BertTokenizerFast,
     CLIPImageProcessor,
+    CLIPTextConfig,
     CLIPTextModel,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
     PNDMScheduler,
     UnCLIPScheduler,
 )
-from ...utils import is_omegaconf_available, is_safetensors_available, logging
+from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
 from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from ..paint_by_example import PaintByExampleImageEncoder
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -770,11 +775,12 @@ def _copy_layers(hf_layers, pt_layers):


 def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
-    text_model = (
-        CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
-        if text_encoder is None
-        else text_encoder
-    )
+    if text_encoder is None:
+        config_name = "openai/clip-vit-large-patch14"
+        config = CLIPTextConfig.from_pretrained(config_name)
+
+        with init_empty_weights():
+            text_model = CLIPTextModel(config)

     keys = list(checkpoint.keys())

@@ -787,7 +793,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
             if key.startswith(prefix):
                 text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

     return text_model

@@ -884,14 +891,26 @@ def convert_paint_by_example_checkpoint(checkpoint):
     return model


-def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
+def convert_open_clip_checkpoint(
+    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+):
     # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
-    text_model = CLIPTextModelWithProjection.from_pretrained(
-        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
-    )
+    # text_model = CLIPTextModelWithProjection.from_pretrained(
+    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+    # )
+    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+
+    with init_empty_weights():
+        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)

     keys = list(checkpoint.keys())

+    keys_to_ignore = []
+    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+        # make sure to remove all keys > 22
+        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+        keys_to_ignore += ["cond_stage_model.model.text_projection"]
+
     text_model_dict = {}

     if prefix + "text_projection" in checkpoint:
@@ -902,8 +921,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

     for key in keys:
-        # if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
-        #     continue
+        if key in keys_to_ignore:
+            continue
         if key[len(prefix) :] in textenc_conversion_map:
             if key.endswith("text_projection"):
                 value = checkpoint[key].T
@@ -931,7 +950,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):

                 text_model_dict[new_key] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

     return text_model

@@ -1061,7 +1081,7 @@ def convert_controlnet_checkpoint(
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
-    image_size: int = 512,
+    image_size: Optional[int] = None,
     prediction_type: str = None,
     model_type: str = None,
     extract_ema: bool = False,
@@ -1144,6 +1164,7 @@ def download_from_original_stable_diffusion_ckpt(
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         StableDiffusionControlNetPipeline,
+        StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLPipeline,
@@ -1166,12 +1187,9 @@ def download_from_original_stable_diffusion_ckpt(
         if not is_safetensors_available():
             raise ValueError(BACKENDS_MAPPING["safetensors"][1])

-        from safetensors import safe_open
+        from safetensors.torch import load_file as safe_load

-        checkpoint = {}
-        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                checkpoint[key] = f.get_tensor(key)
+        checkpoint = safe_load(checkpoint_path, device="cpu")
     else:
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -1183,7 +1201,7 @@ def download_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
-        logger.warning("global_step key not found in model")
+        logger.debug("global_step key not found in model")
         global_step = None

     # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
@@ -1230,9 +1248,15 @@ def download_from_original_stable_diffusion_ckpt(
             model_type = "SDXL"
         else:
             model_type = "SDXL-Refiner"
+        if image_size is None:
+            image_size = 1024

-    if num_in_channels is not None:
-        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+    if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
+        num_in_channels = 9
+    elif num_in_channels is None:
+        num_in_channels = 4
+
+    original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

     if (
         "parameterization" in original_config["model"]["params"]
@@ -1263,7 +1287,6 @@ def download_from_original_stable_diffusion_ckpt(
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

     if model_type in ["SDXL", "SDXL-Refiner"]:
-        image_size = 1024
         scheduler_dict = {
             "beta_schedule": "scaled_linear",
             "beta_start": 0.00085,
@@ -1279,7 +1302,6 @@ def download_from_original_stable_diffusion_ckpt(
         }
         scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
         scheduler_type = "euler"
-        vae_path = "stabilityai/sdxl-vae"
     else:
         beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
         beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
@@ -1318,25 +1340,45 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
-    unet = UNet2DConditionModel(**unet_config)
+    with init_empty_weights():
+        unet = UNet2DConditionModel(**unet_config)

     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
     )
-    unet.load_state_dict(converted_unet_checkpoint)
+
+    for param_name, param in converted_unet_checkpoint.items():
+        set_module_tensor_to_device(unet, param_name, "cpu", value=param)

     # Convert the VAE model.
     if vae_path is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+        if (
+            "model" in original_config
+            and "params" in original_config.model
+            and "scale_factor" in original_config.model.params
+        ):
+            vae_scaling_factor = original_config.model.params.scale_factor
+        else:
+            vae_scaling_factor = 0.18215  # default SD scaling factor
+
+        vae_config["scaling_factor"] = vae_scaling_factor
+
+        with init_empty_weights():
+            vae = AutoencoderKL(**vae_config)
+
+        for param_name, param in converted_vae_checkpoint.items():
+            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)

     if model_type == "FrozenOpenCLIPEmbedder":
-        text_model = convert_open_clip_checkpoint(checkpoint)
+        config_name = "stabilityai/stable-diffusion-2"
+        config_kwargs = {"subfolder": "text_encoder"}
+
+        text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")

         if stable_unclip is None:
@@ -1469,7 +1511,12 @@ def download_from_original_stable_diffusion_ckpt(
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
         tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-        text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+
+        config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+        config_kwargs = {"projection_dim": 1280}
+        text_encoder_2 = convert_open_clip_checkpoint(
+            checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
+        )

         pipe = StableDiffusionXLPipeline(
             vae=vae,
@@ -1485,7 +1532,12 @@ def download_from_original_stable_diffusion_ckpt(
         tokenizer = None
         text_encoder = None
         tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-        text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
+        config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+        config_kwargs = {"projection_dim": 1280}
+        text_encoder_2 = convert_open_clip_checkpoint(
+            checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+        )

         pipe = StableDiffusionXLImg2ImgPipeline(
             vae=vae,
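
The recurring pattern in these hunks swaps `from_pretrained`/`load_state_dict` (which allocates and randomly initializes every weight, only to overwrite it) for instantiation on the meta device followed by copying the converted tensors in directly. A minimal, self-contained sketch of that pattern, assuming `accelerate`, `torch`, and `transformers` are installed; the zero-filled `text_model_dict` below is only a stand-in for the state dict the conversion script actually builds:

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from transformers import CLIPTextConfig, CLIPTextModel

config = CLIPTextConfig.from_pretrained("openai/clip-vit-large-patch14")

# Build the module on the "meta" device: no weight memory is allocated and no
# random initialization runs, which is what speeds up checkpoint conversion.
with init_empty_weights():
    text_model = CLIPTextModel(config)

# Stand-in for the converted checkpoint: same keys and shapes, zero-filled.
text_model_dict = {name: torch.zeros(t.shape, dtype=t.dtype) for name, t in text_model.state_dict().items()}

# Materialize each tensor directly into the module on CPU.
for param_name, param in text_model_dict.items():
    set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
```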