@@ -828,6 +828,87 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
828828 {"checkpoint-4" , "checkpoint-6" },
829829 )
830830
831+ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit (self ):
832+ prompt = "a prompt"
833+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
834+
835+ with tempfile .TemporaryDirectory () as tmpdir :
836+ # Run training script with checkpointing
837+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
838+ # Should create checkpoints at steps 2, 4, 6
839+ # with checkpoint at step 2 deleted
840+
841+ initial_run_args = f"""
842+ examples/text_to_image/train_text_to_image_lora_sdxl.py
843+ --pretrained_model_name_or_path { pipeline_path }
844+ --dataset_name hf-internal-testing/dummy_image_text_data
845+ --resolution 64
846+ --train_batch_size 1
847+ --gradient_accumulation_steps 1
848+ --max_train_steps 7
849+ --learning_rate 5.0e-04
850+ --scale_lr
851+ --lr_scheduler constant
852+ --lr_warmup_steps 0
853+ --output_dir { tmpdir }
854+ --checkpointing_steps=2
855+ --checkpoints_total_limit=2
856+ """ .split ()
857+
858+ run_command (self ._launch_args + initial_run_args )
859+
860+ pipe = DiffusionPipeline .from_pretrained (pipeline_path )
861+ pipe .load_lora_weights (tmpdir )
862+ pipe (prompt , num_inference_steps = 2 )
863+
864+ # check checkpoint directories exist
865+ self .assertEqual (
866+ {x for x in os .listdir (tmpdir ) if "checkpoint" in x },
867+ # checkpoint-2 should have been deleted
868+ {"checkpoint-4" , "checkpoint-6" },
869+ )
870+
871+ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit (self ):
872+ prompt = "a prompt"
873+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
874+
875+ with tempfile .TemporaryDirectory () as tmpdir :
876+ # Run training script with checkpointing
877+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
878+ # Should create checkpoints at steps 2, 4, 6
879+ # with checkpoint at step 2 deleted
880+
881+ initial_run_args = f"""
882+ examples/text_to_image/train_text_to_image_lora_sdxl.py
883+ --pretrained_model_name_or_path { pipeline_path }
884+ --dataset_name hf-internal-testing/dummy_image_text_data
885+ --resolution 64
886+ --train_batch_size 1
887+ --gradient_accumulation_steps 1
888+ --max_train_steps 7
889+ --learning_rate 5.0e-04
890+ --scale_lr
891+ --lr_scheduler constant
892+ --train_text_encoder
893+ --lr_warmup_steps 0
894+ --output_dir { tmpdir }
895+ --checkpointing_steps=2
896+ --checkpoints_total_limit=2
897+ """ .split ()
898+
899+ run_command (self ._launch_args + initial_run_args )
900+
901+ pipe = DiffusionPipeline .from_pretrained (pipeline_path )
902+ pipe .load_lora_weights (tmpdir )
903+ pipe (prompt , num_inference_steps = 2 )
904+
905+ # check checkpoint directories exist
906+ self .assertEqual (
907+ {x for x in os .listdir (tmpdir ) if "checkpoint" in x },
908+ # checkpoint-2 should have been deleted
909+ {"checkpoint-4" , "checkpoint-6" },
910+ )
911+
831912 def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints (self ):
832913 pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
833914 prompt = "a prompt"
0 commit comments