manual check for checkpoints_total_limit instead of using accelerate #3681
Conversation
The documentation is not available anymore as the PR was closed or merged.
Nice!
This came off cleaner than #3652, no?
Perfect! Can we maybe apply this change to all the other training scripts as well?
Yep!
Force-pushed from e903b71 to 99d53c2.
if len(images) != 0:
    if tracker.name == "tensorboard":
        np_images = np.stack([np.asarray(img) for img in images])
        tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
    if tracker.name == "wandb":
        tracker.log(
            {
                "test": [
                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                    for i, image in enumerate(images)
                ]
            }
        )
Added a length check for when num_validation_images is zero; without it, np.stack throws an error trying to stack an empty array. I had to set num_validation_images to zero because the dummy pipeline throws an error during inference. It would be ideal to fix the dummy pipeline along with the training script, but this is an OK workaround.
def test_text_to_image_checkpointing_checkpoints_total_limit(self):
Every training script needs two tests.
One: that saving one more checkpoint than the limit allows deletes the earliest checkpoint (sketched below).
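A minimal sketch of what that first test could check, assuming the text-to-image script's standard CLI flags and eliding the required model/dataset arguments (illustrative, not the PR's actual test code):

import os
import subprocess
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    # train for 6 steps, checkpointing every 2 steps with a limit of 2;
    # saving checkpoint-6 should prune checkpoint-2, the oldest
    subprocess.run(
        [
            "python", "examples/text_to_image/train_text_to_image.py",
            "--output_dir", tmpdir,
            "--max_train_steps", "6",
            "--checkpointing_steps", "2",
            "--checkpoints_total_limit", "2",
            # ...required model and dataset arguments elided...
        ],
        check=True,
    )
    found = {d for d in os.listdir(tmpdir) if d.startswith("checkpoint")}
    assert found == {"checkpoint-4", "checkpoint-6"}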
{"checkpoint-4", "checkpoint-6"}, | ||
) | ||
|
||
def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): |
Two: that restarting training with a smaller checkpoints_total_limit deletes checkpoints, starting from the oldest, until we're down to the new limit (see the sketch below).
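A matching sketch for the second test, again with the required model/dataset arguments elided and the step counts chosen for illustration:

import os
import subprocess
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    base = [
        "python", "examples/text_to_image/train_text_to_image.py",
        "--output_dir", tmpdir,
        "--checkpointing_steps", "2",
        # ...required model and dataset arguments elided...
    ]
    # first run: no checkpoints_total_limit, so checkpoint-2/4/6 all survive
    subprocess.run(base + ["--max_train_steps", "6"], check=True)
    # resume with a smaller limit: before saving checkpoint-8, the pruning
    # loop must delete checkpoint-2 AND checkpoint-4 (multiple removals)
    subprocess.run(
        base
        + [
            "--resume_from_checkpoint", "latest",
            "--max_train_steps", "8",
            "--checkpoints_total_limit", "2",
        ],
        check=True,
    )
    found = {d for d in os.listdir(tmpdir) if d.startswith("checkpoint")}
    assert found == {"checkpoint-6", "checkpoint-8"}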
Force-pushed from 99d53c2 to a34c45e.
if accelerator.is_main_process:
    if global_step % args.checkpointing_steps == 0:
        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
        if args.checkpoints_total_limit is not None:
            checkpoints = os.listdir(args.output_dir)
            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
            if len(checkpoints) >= args.checkpoints_total_limit:
                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                removing_checkpoints = checkpoints[0:num_to_remove]

                logger.info(
                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                )
                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                for removing_checkpoint in removing_checkpoints:
                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                    shutil.rmtree(removing_checkpoint)
just using this snippet in each training script
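For illustration only, the duplicated snippet could be factored into a helper like the hypothetical prune_checkpoints below; the PR deliberately inlines the logic in each training script instead:

import os
import shutil

def prune_checkpoints(output_dir: str, total_limit: int) -> None:
    """Delete the oldest `checkpoint-*` dirs so that at most `total_limit - 1`
    remain before a new checkpoint is saved (same logic as the inlined snippet)."""
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
    if len(checkpoints) >= total_limit:
        num_to_remove = len(checkpoints) - total_limit + 1
        for name in checkpoints[:num_to_remove]:
            shutil.rmtree(os.path.join(output_dir, name))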
Looks good to me!
Would be great if we could try to remove this new argument: https://github.com/huggingface/diffusers/pull/3681/files#r1225441084
Looks great! I agree with Patrick that it'd be awesome to remove that new argument if we can.
Force-pushed from a34c45e to 2a8cfd8.
Removed the no-longer-needed ControlNet argument!
…sers#3681)
- add an lr_end parameter for setting that value
- fix the use of lr_power

…uggingface#3681)
* manual check for checkpoints_total_limit instead of using accelerate
* remove controlnet_conditioning_embedding_out_channels
re: #2466, #3652, and #3802
see PR comments