@@ -421,6 +421,77 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self):
421421 )
422422 self .assertTrue (starts_with_unet )
423423
424+ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit (self ):
425+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
426+
427+ with tempfile .TemporaryDirectory () as tmpdir :
428+ test_args = f"""
429+ examples/dreambooth/train_dreambooth_lora_sdxl.py
430+ --pretrained_model_name_or_path { pipeline_path }
431+ --instance_data_dir docs/source/en/imgs
432+ --instance_prompt photo
433+ --resolution 64
434+ --train_batch_size 1
435+ --gradient_accumulation_steps 1
436+ --max_train_steps 7
437+ --checkpointing_steps=2
438+ --checkpoints_total_limit=2
439+ --learning_rate 5.0e-04
440+ --scale_lr
441+ --lr_scheduler constant
442+ --lr_warmup_steps 0
443+ --output_dir { tmpdir }
444+ """ .split ()
445+
446+ run_command (self ._launch_args + test_args )
447+
448+ pipe = DiffusionPipeline .from_pretrained (pipeline_path )
449+ pipe .load_lora_weights (tmpdir )
450+ pipe ("a prompt" , num_inference_steps = 2 )
451+
452+ # check checkpoint directories exist
453+ self .assertEqual (
454+ {x for x in os .listdir (tmpdir ) if "checkpoint" in x },
455+ # checkpoint-2 should have been deleted
456+ {"checkpoint-4" , "checkpoint-6" },
457+ )
458+
459+ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit (self ):
460+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
461+
462+ with tempfile .TemporaryDirectory () as tmpdir :
463+ test_args = f"""
464+ examples/dreambooth/train_dreambooth_lora_sdxl.py
465+ --pretrained_model_name_or_path { pipeline_path }
466+ --instance_data_dir docs/source/en/imgs
467+ --instance_prompt photo
468+ --resolution 64
469+ --train_batch_size 1
470+ --gradient_accumulation_steps 1
471+ --max_train_steps 7
472+ --checkpointing_steps=2
473+ --checkpoints_total_limit=2
474+ --train_text_encoder
475+ --learning_rate 5.0e-04
476+ --scale_lr
477+ --lr_scheduler constant
478+ --lr_warmup_steps 0
479+ --output_dir { tmpdir }
480+ """ .split ()
481+
482+ run_command (self ._launch_args + test_args )
483+
484+ pipe = DiffusionPipeline .from_pretrained (pipeline_path )
485+ pipe .load_lora_weights (tmpdir )
486+ pipe ("a prompt" , num_inference_steps = 2 )
487+
488+ # check checkpoint directories exist
489+ self .assertEqual (
490+ {x for x in os .listdir (tmpdir ) if "checkpoint" in x },
491+ # checkpoint-2 should have been deleted
492+ {"checkpoint-4" , "checkpoint-6" },
493+ )
494+
424495 def test_custom_diffusion (self ):
425496 with tempfile .TemporaryDirectory () as tmpdir :
426497 test_args = f"""
0 commit comments