Commit e65b71a

Add an explicit --image_size to the conversion script (huggingface#1509)
* Add an explicit `--image_size` to the conversion script
* style
1 parent a6a25ce commit e65b71a

File tree

1 file changed (+18 −7 lines)


scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 18 additions & 7 deletions
@@ -207,12 +207,12 @@ def conv_attn_to_linear(checkpoint):
                 checkpoint[key] = checkpoint[key][:, :, 0]
 
 
-def create_unet_diffusers_config(original_config):
+def create_unet_diffusers_config(original_config, image_size: int):
     """
     Creates a config for the diffusers based on the config of the LDM model.
     """
-    model_params = original_config.model.params
     unet_params = original_config.model.params.unet_config.params
+    vae_params = original_config.model.params.first_stage_config.params.ddconfig
 
     block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
 
@@ -230,8 +230,10 @@ def create_unet_diffusers_config(original_config):
         up_block_types.append(block_type)
         resolution //= 2
 
+    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+
     config = dict(
-        sample_size=model_params.image_size,
+        sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
         out_channels=unet_params.out_channels,
         down_block_types=tuple(down_block_types),
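For reference, a minimal sketch of the new latent-size arithmetic, assuming the usual Stable Diffusion VAE ddconfig with ch_mult = [1, 2, 4, 4] (an assumed value for illustration, not read from any checkpoint):

# Sketch of the sample_size arithmetic introduced above.
# Assumes the typical Stable Diffusion VAE ddconfig ch_mult = [1, 2, 4, 4].
ch_mult = [1, 2, 4, 4]
vae_scale_factor = 2 ** (len(ch_mult) - 1)  # 8x spatial downscaling in the VAE

for image_size in (512, 768):
    # The UNet works in latent space, so its sample_size is the pixel size
    # divided by the VAE scale factor: 512 -> 64, 768 -> 96.
    print(image_size, "->", image_size // vae_scale_factor)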
@@ -245,7 +247,7 @@ def create_unet_diffusers_config(original_config):
     return config
 
 
-def create_vae_diffusers_config(original_config):
+def create_vae_diffusers_config(original_config, image_size: int):
     """
     Creates a config for the diffusers based on the config of the LDM model.
     """
@@ -257,7 +259,7 @@ def create_vae_diffusers_config(original_config):
     up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
 
     config = dict(
-        sample_size=vae_params.resolution,
+        sample_size=image_size,
         in_channels=vae_params.in_channels,
         out_channels=vae_params.out_ch,
         down_block_types=tuple(down_block_types),
@@ -653,6 +655,15 @@ def convert_ldm_clip_checkpoint(checkpoint):
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
     )
+    parser.add_argument(
+        "--image_size",
+        default=512,
+        type=int,
+        help=(
+            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
+            " Base. Use 768 for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
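A small, self-contained illustration of the new flag's default behaviour; only --image_size and its default of 512 come from this diff, and the standalone parser is hypothetical:

import argparse

# Hypothetical standalone parser mirroring just the new flag.
parser = argparse.ArgumentParser()
parser.add_argument("--image_size", default=512, type=int)

print(parser.parse_args([]).image_size)                        # 512 (v1.X / v2-base default)
print(parser.parse_args(["--image_size", "768"]).image_size)   # 768 for Stable Diffusion v2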
@@ -712,7 +723,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
 
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config)
+    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
@@ -721,7 +732,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
     unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config)
+    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
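To tie it together, a sketch of how one might invoke the updated script for a 768px Stable Diffusion v2 checkpoint. Only --image_size comes from this commit; --checkpoint_path is implied by args.checkpoint_path in the diff, while --dump_path and both file paths are assumptions used purely for illustration:

import subprocess

# Assumed invocation; paths are placeholders and --dump_path is not shown in this diff.
subprocess.run(
    [
        "python", "scripts/convert_original_stable_diffusion_to_diffusers.py",
        "--checkpoint_path", "path/to/sd-v2-768.ckpt",
        "--image_size", "768",           # 768px model -> UNet sample_size of 96
        "--dump_path", "path/to/output",
    ],
    check=True,
)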
