Skip to content

Commit 4447547

Browse files
authored
[Examples] fix sdxl dreambooth lora checkpointing. (huggingface#4749)
* fix sdxl dreambooth lora checkpointing. * style
1 parent 5222294 commit 4447547

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,11 +843,15 @@ def load_model_hook(models, input_dir):
843843

844844
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
845845
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
846+
847+
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
846848
LoraLoaderMixin.load_lora_into_text_encoder(
847-
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
849+
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
848850
)
851+
852+
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
849853
LoraLoaderMixin.load_lora_into_text_encoder(
850-
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
854+
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
851855
)
852856

853857
accelerator.register_save_state_pre_hook(save_model_hook)

examples/test_examples.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,77 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self):
421421
)
422422
self.assertTrue(starts_with_unet)
423423

424+
def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
425+
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
426+
427+
with tempfile.TemporaryDirectory() as tmpdir:
428+
test_args = f"""
429+
examples/dreambooth/train_dreambooth_lora_sdxl.py
430+
--pretrained_model_name_or_path {pipeline_path}
431+
--instance_data_dir docs/source/en/imgs
432+
--instance_prompt photo
433+
--resolution 64
434+
--train_batch_size 1
435+
--gradient_accumulation_steps 1
436+
--max_train_steps 7
437+
--checkpointing_steps=2
438+
--checkpoints_total_limit=2
439+
--learning_rate 5.0e-04
440+
--scale_lr
441+
--lr_scheduler constant
442+
--lr_warmup_steps 0
443+
--output_dir {tmpdir}
444+
""".split()
445+
446+
run_command(self._launch_args + test_args)
447+
448+
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
449+
pipe.load_lora_weights(tmpdir)
450+
pipe("a prompt", num_inference_steps=2)
451+
452+
# check checkpoint directories exist
453+
self.assertEqual(
454+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
455+
# checkpoint-2 should have been deleted
456+
{"checkpoint-4", "checkpoint-6"},
457+
)
458+
459+
def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
460+
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
461+
462+
with tempfile.TemporaryDirectory() as tmpdir:
463+
test_args = f"""
464+
examples/dreambooth/train_dreambooth_lora_sdxl.py
465+
--pretrained_model_name_or_path {pipeline_path}
466+
--instance_data_dir docs/source/en/imgs
467+
--instance_prompt photo
468+
--resolution 64
469+
--train_batch_size 1
470+
--gradient_accumulation_steps 1
471+
--max_train_steps 7
472+
--checkpointing_steps=2
473+
--checkpoints_total_limit=2
474+
--train_text_encoder
475+
--learning_rate 5.0e-04
476+
--scale_lr
477+
--lr_scheduler constant
478+
--lr_warmup_steps 0
479+
--output_dir {tmpdir}
480+
""".split()
481+
482+
run_command(self._launch_args + test_args)
483+
484+
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
485+
pipe.load_lora_weights(tmpdir)
486+
pipe("a prompt", num_inference_steps=2)
487+
488+
# check checkpoint directories exist
489+
self.assertEqual(
490+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
491+
# checkpoint-2 should have been deleted
492+
{"checkpoint-4", "checkpoint-6"},
493+
)
494+
424495
def test_custom_diffusion(self):
425496
with tempfile.TemporaryDirectory() as tmpdir:
426497
test_args = f"""

0 commit comments

Comments
 (0)