Commit 4909b1e

[Examples] fix checkpointing and casting bugs in train_text_to_image_lora_sdxl.py (huggingface#4632)
* fix: casting issues.
* fix checkpointing.
* tests
* fix: bugs
1 parent 052bf32 commit 4909b1e

File tree

examples/test_examples.py
examples/text_to_image/train_text_to_image_lora_sdxl.py

2 files changed: 92 additions, 21 deletions

examples/test_examples.py

Lines changed: 81 additions & 0 deletions
@@ -828,6 +828,87 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
                 {"checkpoint-4", "checkpoint-6"},
             )
 
+    def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+        prompt = "a prompt"
+        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+            # Should create checkpoints at steps 2, 4, 6
+            # with checkpoint at step 2 deleted
+
+            initial_run_args = f"""
+                examples/text_to_image/train_text_to_image_lora_sdxl.py
+                --pretrained_model_name_or_path {pipeline_path}
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+            pipe.load_lora_weights(tmpdir)
+            pipe(prompt, num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+        prompt = "a prompt"
+        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+            # Should create checkpoints at steps 2, 4, 6
+            # with checkpoint at step 2 deleted
+
+            initial_run_args = f"""
+                examples/text_to_image/train_text_to_image_lora_sdxl.py
+                --pretrained_model_name_or_path {pipeline_path}
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --train_text_encoder
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+            pipe.load_lora_weights(tmpdir)
+            pipe(prompt, num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
     def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
         pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
         prompt = "a prompt"

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 11 additions & 21 deletions
@@ -396,16 +396,6 @@ def parse_args(input_args=None):
                 " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
             ),
         )
-    parser.add_argument(
-        "--prior_generation_precision",
-        type=str,
-        default=None,
-        choices=["no", "fp32", "fp16", "bf16"],
-        help=(
-            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
-        ),
-    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
@@ -724,11 +714,15 @@ def load_model_hook(models, input_dir):
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
         LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
         )
+
+        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)
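
Why the filtering above fixes checkpoint loading: the combined lora_state_dict holds keys for the UNet and both SDXL text encoders, and passing it whole to load_lora_into_text_encoder let each encoder see the other's keys. A small self-contained illustration of the prefix filter (the key names below are made up; only the "text_encoder." / "text_encoder_2." prefixes match the real layout):

# Hypothetical LoRA keys covering the UNet and both text encoders.
lora_state_dict = {
    "unet.up_blocks.0.attentions.0.lora.down.weight": 0,
    "text_encoder.text_model.encoder.layers.0.lora.down.weight": 1,
    "text_encoder_2.text_model.encoder.layers.0.lora.down.weight": 2,
}

te_1 = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
te_2 = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}

# "text_encoder_2." keys never match the first filter: the substring check
# requires a literal "." directly after "text_encoder", but there it is "_2".
assert list(te_1) == ["text_encoder.text_model.encoder.layers.0.lora.down.weight"]
assert list(te_2) == ["text_encoder_2.text_model.encoder.layers.0.lora.down.weight"]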
@@ -1002,9 +996,12 @@ def collate_fn(examples):
                     continue
 
             with accelerator.accumulate(unet):
-                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
-
                 # Convert images to latent space
+                if args.pretrained_vae_model_name_or_path is not None:
+                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+                else:
+                    pixel_values = batch["pixel_values"]
+
                 model_input = vae.encode(pixel_values).latent_dist.sample()
                 model_input = model_input * vae.config.scaling_factor
                 if args.pretrained_vae_model_name_or_path is None:
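
Background on the casting change: when --pretrained_vae_model_name_or_path is not passed, the script keeps the stock SDXL VAE in fp32 (the released SDXL VAE is known to be numerically unstable in fp16), so unconditionally casting pixel_values to weight_dtype fed fp16 tensors into fp32 VAE weights. A minimal reproduction of that mismatch, independent of the script:

import torch

# Stand-in for an fp32 VAE encoder (the default SDXL VAE stays in float32).
fp32_encoder = torch.nn.Conv2d(3, 4, kernel_size=3)

half_pixels = torch.randn(1, 3, 64, 64, dtype=torch.float16)

try:
    fp32_encoder(half_pixels)  # fp16 input vs. fp32 weights -> RuntimeError
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

fp32_encoder(half_pixels.float())  # matching dtypes, as the fixed branch ensures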
@@ -1147,13 +1144,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     f" {args.validation_prompt}."
                 )
                 # create pipeline
-                if not args.train_text_encoder:
-                    text_encoder_one = text_encoder_cls_one.from_pretrained(
-                        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-                    )
-                    text_encoder_two = text_encoder_cls_two.from_pretrained(
-                        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
-                    )
                 pipeline = StableDiffusionXLPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
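
Why the deleted block was unnecessary: when --train_text_encoder is off, the script already holds frozen text encoders loaded during setup, so re-instantiating fresh copies right before validation was redundant (and presumably reloaded them in full precision regardless of the mixed-precision setup). Since DiffusionPipeline.from_pretrained lets keyword arguments override components loaded from the checkpoint (as vae=vae does above), the existing encoders can be passed straight into the validation pipeline.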
