Skip to content

Commit f3d1333

Browse files
authored
Improve LCM(-LoRA) Distillation Scripts (huggingface#6420)
* Make WDS pipeline interpolation type configurable.
* Make the VAE encoding batch size configurable.
* Make lora_alpha and lora_dropout configurable for LCM LoRA scripts.
* Generalize scalings_for_boundary_conditions function and make the timestep scaling configurable.
* Make LoRA target modules configurable for LCM-LoRA scripts.
* Move resolve_interpolation_mode to src/diffusers/training_utils.py and make interpolation type configurable in non-WDS script.
* Apply suggestions from review.
1 parent acd926f commit f3d1333

File tree

6 files changed

+374
-53
lines changed

6 files changed

+374
-53
lines changed

examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
UNet2DConditionModel,
6262
)
6363
from diffusers.optimization import get_scheduler
64+
from diffusers.training_utils import resolve_interpolation_mode
6465
from diffusers.utils import check_min_version, is_wandb_available
6566
from diffusers.utils.import_utils import is_xformers_available
6667

@@ -165,6 +166,7 @@ def __init__(
165166
global_batch_size: int,
166167
num_workers: int,
167168
resolution: int = 512,
169+
interpolation_type: str = "bilinear",
168170
shuffle_buffer_size: int = 1000,
169171
pin_memory: bool = False,
170172
persistent_workers: bool = False,
@@ -174,10 +176,12 @@ def __init__(
174176
# flatten list using itertools
175177
train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
176178

179+
interpolation_mode = resolve_interpolation_mode(interpolation_type)
180+
177181
def transform(example):
178182
# resize image
179183
image = example["image"]
180-
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
184+
image = TF.resize(image, resolution, interpolation=interpolation_mode)
181185

182186
# get crop coordinates and crop image
183187
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
@@ -353,8 +357,9 @@ def append_dims(x, target_dims):
353357

354358
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
355359
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
356-
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
357-
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
360+
scaled_timestep = timestep_scaling * timestep
361+
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
362+
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
358363
return c_skip, c_out
359364

360365

@@ -572,6 +577,15 @@ def parse_args():
572577
" resolution"
573578
),
574579
)
580+
parser.add_argument(
581+
"--interpolation_type",
582+
type=str,
583+
default="bilinear",
584+
help=(
585+
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
586+
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
587+
),
588+
)
575589
parser.add_argument(
576590
"--center_crop",
577591
default=False,
@@ -710,6 +724,50 @@ def parse_args():
710724
default=64,
711725
help="The rank of the LoRA projection matrix.",
712726
)
727+
parser.add_argument(
728+
"--lora_alpha",
729+
type=int,
730+
default=64,
731+
help=(
732+
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
733+
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
734+
),
735+
)
736+
parser.add_argument(
737+
"--lora_dropout",
738+
type=float,
739+
default=0.0,
740+
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
741+
)
742+
parser.add_argument(
743+
"--lora_target_modules",
744+
type=str,
745+
default=None,
746+
help=(
747+
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
748+
" be used. By default, LoRA will be applied to all conv and linear layers."
749+
),
750+
)
751+
parser.add_argument(
752+
"--vae_encode_batch_size",
753+
type=int,
754+
default=32,
755+
required=False,
756+
help=(
757+
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
758+
" Encoding or decoding the whole batch at once may run into OOM issues."
759+
),
760+
)
761+
parser.add_argument(
762+
"--timestep_scaling_factor",
763+
type=float,
764+
default=10.0,
765+
help=(
766+
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
767+
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
768+
" suffice."
769+
),
770+
)
713771
# ----Mixed Precision----
714772
parser.add_argument(
715773
"--mixed_precision",
@@ -915,9 +973,10 @@ def main(args):
915973
)
916974

917975
# 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
918-
lora_config = LoraConfig(
919-
r=args.lora_rank,
920-
target_modules=[
976+
if args.lora_target_modules is not None:
977+
lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
978+
else:
979+
lora_target_modules = [
921980
"to_q",
922981
"to_k",
923982
"to_v",
@@ -932,7 +991,12 @@ def main(args):
932991
"downsamplers.0.conv",
933992
"upsamplers.0.conv",
934993
"time_emb_proj",
935-
],
994+
]
995+
lora_config = LoraConfig(
996+
r=args.lora_rank,
997+
target_modules=lora_target_modules,
998+
lora_alpha=args.lora_alpha,
999+
lora_dropout=args.lora_dropout,
9361000
)
9371001
unet = get_peft_model(unet, lora_config)
9381002

@@ -1051,6 +1115,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
10511115
global_batch_size=args.train_batch_size * accelerator.num_processes,
10521116
num_workers=args.dataloader_num_workers,
10531117
resolution=args.resolution,
1118+
interpolation_type=args.interpolation_type,
10541119
shuffle_buffer_size=1000,
10551120
pin_memory=True,
10561121
persistent_workers=True,
@@ -1162,10 +1227,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
11621227
if vae.dtype != weight_dtype:
11631228
vae.to(dtype=weight_dtype)
11641229

1165-
# encode pixel values with batch size of at most 32
1230+
# encode pixel values with batch size of at most args.vae_encode_batch_size
11661231
latents = []
1167-
for i in range(0, pixel_values.shape[0], 32):
1168-
latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
1232+
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
1233+
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
11691234
latents = torch.cat(latents, dim=0)
11701235

11711236
latents = latents * vae.config.scaling_factor
@@ -1181,9 +1246,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
11811246
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
11821247

11831248
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
1184-
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
1249+
c_skip_start, c_out_start = scalings_for_boundary_conditions(
1250+
start_timesteps, timestep_scaling=args.timestep_scaling_factor
1251+
)
11851252
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
1186-
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
1253+
c_skip, c_out = scalings_for_boundary_conditions(
1254+
timesteps, timestep_scaling=args.timestep_scaling_factor
1255+
)
11871256
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
11881257

11891258
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each

examples/consistency_distillation/train_lcm_distill_lora_sdxl.py

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
UNet2DConditionModel,
5252
)
5353
from diffusers.optimization import get_scheduler
54+
from diffusers.training_utils import resolve_interpolation_mode
5455
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
5556
from diffusers.utils.import_utils import is_xformers_available
5657

@@ -193,8 +194,9 @@ def append_dims(x, target_dims):
193194

194195
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
195196
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
196-
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
197-
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
197+
scaled_timestep = timestep_scaling * timestep
198+
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
199+
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
198200
return c_skip, c_out
199201

200202

@@ -396,6 +398,15 @@ def parse_args():
396398
" resolution"
397399
),
398400
)
401+
parser.add_argument(
402+
"--interpolation_type",
403+
type=str,
404+
default="bilinear",
405+
help=(
406+
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
407+
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
408+
),
409+
)
399410
parser.add_argument(
400411
"--center_crop",
401412
default=False,
@@ -534,6 +545,50 @@ def parse_args():
534545
default=64,
535546
help="The rank of the LoRA projection matrix.",
536547
)
548+
parser.add_argument(
549+
"--lora_alpha",
550+
type=int,
551+
default=64,
552+
help=(
553+
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
554+
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
555+
),
556+
)
557+
parser.add_argument(
558+
"--lora_dropout",
559+
type=float,
560+
default=0.0,
561+
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
562+
)
563+
parser.add_argument(
564+
"--lora_target_modules",
565+
type=str,
566+
default=None,
567+
help=(
568+
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
569+
" be used. By default, LoRA will be applied to all conv and linear layers."
570+
),
571+
)
572+
parser.add_argument(
573+
"--vae_encode_batch_size",
574+
type=int,
575+
default=8,
576+
required=False,
577+
help=(
578+
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
579+
" Encoding or decoding the whole batch at once may run into OOM issues."
580+
),
581+
)
582+
parser.add_argument(
583+
"--timestep_scaling_factor",
584+
type=float,
585+
default=10.0,
586+
help=(
587+
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
588+
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
589+
" suffice."
590+
),
591+
)
537592
# ----Mixed Precision----
538593
parser.add_argument(
539594
"--mixed_precision",
@@ -776,10 +831,10 @@ def main(args):
776831
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
777832

778833
# 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
779-
lora_config = LoraConfig(
780-
r=args.lora_rank,
781-
lora_alpha=args.lora_rank,
782-
target_modules=[
834+
if args.lora_target_modules is not None:
835+
lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
836+
else:
837+
lora_target_modules = [
783838
"to_q",
784839
"to_k",
785840
"to_v",
@@ -794,7 +849,12 @@ def main(args):
794849
"downsamplers.0.conv",
795850
"upsamplers.0.conv",
796851
"time_emb_proj",
797-
],
852+
]
853+
lora_config = LoraConfig(
854+
r=args.lora_rank,
855+
target_modules=lora_target_modules,
856+
lora_alpha=args.lora_alpha,
857+
lora_dropout=args.lora_dropout,
798858
)
799859
unet.add_adapter(lora_config)
800860

@@ -929,7 +989,8 @@ def load_model_hook(models, input_dir):
929989
)
930990

931991
# Preprocessing the datasets.
932-
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
992+
interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
993+
train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode)
933994
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
934995
train_flip = transforms.RandomHorizontalFlip(p=1.0)
935996
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
@@ -1121,11 +1182,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
11211182

11221183
encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
11231184

1124-
# encode pixel values with batch size of at most 8
1185+
# encode pixel values with batch size of at most args.vae_encode_batch_size
11251186
pixel_values = pixel_values.to(dtype=vae.dtype)
11261187
latents = []
1127-
for i in range(0, pixel_values.shape[0], args.encode_batch_size):
1128-
latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample())
1188+
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
1189+
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
11291190
latents = torch.cat(latents, dim=0)
11301191

11311192
latents = latents * vae.config.scaling_factor
@@ -1142,9 +1203,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
11421203
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
11431204

11441205
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
1145-
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
1206+
c_skip_start, c_out_start = scalings_for_boundary_conditions(
1207+
start_timesteps, timestep_scaling=args.timestep_scaling_factor
1208+
)
11461209
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
1147-
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
1210+
c_skip, c_out = scalings_for_boundary_conditions(
1211+
timesteps, timestep_scaling=args.timestep_scaling_factor
1212+
)
11481213
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
11491214

11501215
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each

0 commit comments

Comments
 (0)