@@ -207,12 +207,12 @@ def conv_attn_to_linear(checkpoint):
207207 checkpoint [key ] = checkpoint [key ][:, :, 0 ]
208208
209209
210- def create_unet_diffusers_config (original_config ):
210+ def create_unet_diffusers_config (original_config , image_size : int ):
211211 """
212212 Creates a config for the diffusers based on the config of the LDM model.
213213 """
214- model_params = original_config .model .params
215214 unet_params = original_config .model .params .unet_config .params
215+ vae_params = original_config .model .params .first_stage_config .params .ddconfig
216216
217217 block_out_channels = [unet_params .model_channels * mult for mult in unet_params .channel_mult ]
218218
@@ -230,8 +230,10 @@ def create_unet_diffusers_config(original_config):
230230 up_block_types .append (block_type )
231231 resolution //= 2
232232
233+ vae_scale_factor = 2 ** (len (vae_params .ch_mult ) - 1 )
234+
233235 config = dict (
234- sample_size = model_params . image_size ,
236+ sample_size = image_size // vae_scale_factor ,
235237 in_channels = unet_params .in_channels ,
236238 out_channels = unet_params .out_channels ,
237239 down_block_types = tuple (down_block_types ),
@@ -245,7 +247,7 @@ def create_unet_diffusers_config(original_config):
245247 return config
246248
247249
248- def create_vae_diffusers_config (original_config ):
250+ def create_vae_diffusers_config (original_config , image_size : int ):
249251 """
250252 Creates a config for the diffusers based on the config of the LDM model.
251253 """
@@ -257,7 +259,7 @@ def create_vae_diffusers_config(original_config):
257259 up_block_types = ["UpDecoderBlock2D" ] * len (block_out_channels )
258260
259261 config = dict (
260- sample_size = vae_params . resolution ,
262+ sample_size = image_size ,
261263 in_channels = vae_params .in_channels ,
262264 out_channels = vae_params .out_ch ,
263265 down_block_types = tuple (down_block_types ),
@@ -653,6 +655,15 @@ def convert_ldm_clip_checkpoint(checkpoint):
653655 type = str ,
654656 help = "Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']" ,
655657 )
658+ parser .add_argument (
659+ "--image_size" ,
660+ default = 512 ,
661+ type = int ,
662+ help = (
663+ "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
664+ " Base. Use 768 for Stable Diffusion v2."
665+ ),
666+ )
656667 parser .add_argument (
657668 "--extract_ema" ,
658669 action = "store_true" ,
@@ -712,7 +723,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
712723 raise ValueError (f"Scheduler of type { args .scheduler_type } doesn't exist!" )
713724
714725 # Convert the UNet2DConditionModel model.
715- unet_config = create_unet_diffusers_config (original_config )
726+ unet_config = create_unet_diffusers_config (original_config , image_size = args . image_size )
716727 converted_unet_checkpoint = convert_ldm_unet_checkpoint (
717728 checkpoint , unet_config , path = args .checkpoint_path , extract_ema = args .extract_ema
718729 )
@@ -721,7 +732,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
721732 unet .load_state_dict (converted_unet_checkpoint )
722733
723734 # Convert the VAE model.
724- vae_config = create_vae_diffusers_config (original_config )
735+ vae_config = create_vae_diffusers_config (original_config , image_size = args . image_size )
725736 converted_vae_checkpoint = convert_ldm_vae_checkpoint (checkpoint , vae_config )
726737
727738 vae = AutoencoderKL (** vae_config )
0 commit comments