1515""" Conversion script for the Stable Diffusion checkpoints."""
1616
1717import re
18+ from contextlib import nullcontext
1819from io import BytesIO
1920from typing import Optional
2021
@@ -779,7 +780,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
     config_name = "openai/clip-vit-large-patch14"
     config = CLIPTextConfig.from_pretrained(config_name)
 
-    with init_empty_weights():
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
         text_model = CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
@@ -793,8 +795,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
             if key.startswith(prefix):
                 text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
 
-    for param_name, param in text_model_dict.items():
-        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    else:
+        text_model.load_state_dict(text_model_dict)
 
     return text_model
 
@@ -900,7 +905,8 @@ def convert_open_clip_checkpoint(
     # )
     config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
 
-    with init_empty_weights():
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
         text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
@@ -950,8 +956,11 @@ def convert_open_clip_checkpoint(
 
             text_model_dict[new_key] = checkpoint[key]
 
-    for param_name, param in text_model_dict.items():
-        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    else:
+        text_model.load_state_dict(text_model_dict)
 
     return text_model
 
@@ -1172,11 +1181,6 @@ def download_from_original_stable_diffusion_ckpt(
         StableUnCLIPPipeline,
     )
 
-    if not is_accelerate_available():
-        raise ImportError(
-            "To correctly use `from_single_file`, please make sure that `accelerate` is installed. You can install it with `pip install accelerate`."
-        )
-
     if pipeline_class is None:
         pipeline_class = StableDiffusionPipeline
 
@@ -1346,15 +1350,19 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
-    with init_empty_weights():
-        unet = UNet2DConditionModel(**unet_config)
-
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
     )
 
-    for param_name, param in converted_unet_checkpoint.items():
-        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        unet = UNet2DConditionModel(**unet_config)
+
+    if is_accelerate_available():
+        for param_name, param in converted_unet_checkpoint.items():
+            set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    else:
+        unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
     if vae_path is None:
@@ -1372,11 +1380,15 @@ def download_from_original_stable_diffusion_ckpt(
 
         vae_config["scaling_factor"] = vae_scaling_factor
 
-        with init_empty_weights():
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
             vae = AutoencoderKL(**vae_config)
 
-        for param_name, param in converted_vae_checkpoint.items():
-            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+        if is_accelerate_available():
+            for param_name, param in converted_vae_checkpoint.items():
+                set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+        else:
+            vae.load_state_dict(converted_vae_checkpoint)
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)
 
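The change makes `accelerate` optional for this conversion path: when it is installed, each model is still instantiated on the meta device via `init_empty_weights` and the converted weights are materialized one tensor at a time with `set_module_tensor_to_device`; otherwise the model is built normally and loaded with a plain `load_state_dict`. A minimal, self-contained sketch of that fallback pattern (the `load_converted_weights` helper and the toy `nn.Linear` are illustrative stand-ins, not part of the script):

```python
from contextlib import nullcontext

import torch

try:
    # `init_empty_weights` builds modules on the meta device (no memory allocated);
    # `set_module_tensor_to_device` materializes individual tensors afterwards.
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    ACCELERATE_AVAILABLE = True
except ImportError:
    ACCELERATE_AVAILABLE = False


def load_converted_weights(state_dict):
    # Toy nn.Linear stands in for the UNet/VAE/CLIP models of the real script.
    ctx = init_empty_weights if ACCELERATE_AVAILABLE else nullcontext
    with ctx():
        model = torch.nn.Linear(4, 4)

    if ACCELERATE_AVAILABLE:
        # Assign each converted tensor directly onto the meta-initialized module, on CPU.
        for name, tensor in state_dict.items():
            set_module_tensor_to_device(model, name, "cpu", value=tensor)
    else:
        # Plain PyTorch fallback: the module already has real parameters to overwrite.
        model.load_state_dict(state_dict)

    return model


weights = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
model = load_converted_weights(weights)
```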