manual check for checkpoints_total_limit instead of using accelerate #3681

Merged: williamberman merged 2 commits into huggingface:main from enforce_total_limit on Jun 15, 2023

Conversation


@williamberman williamberman commented Jun 5, 2023

re: #2466 and #3652 and #3802

see PR comments

HuggingFaceDocBuilderDev commented Jun 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@sayakpaul sayakpaul (Member) left a comment

Nice!

This came off cleaner than #3652, no?

@williamberman williamberman (Contributor, Author)

> Nice!
>
> This came off cleaner than #3652, no?

Eh, IMO while #3652 was a slightly bigger diff in the training script, I do prefer it to this approach, since with this one we'll now have to add tests to each training script verifying that it properly removes checkpoints.

@patrickvonplaten patrickvonplaten (Contributor) left a comment

Perf! Can we maybe apply this change to all other training scripts as well?

@williamberman williamberman (Contributor, Author)

> Perf! Can we maybe apply this change to all other training scripts as well?

Yep!

@williamberman williamberman force-pushed the enforce_total_limit branch 9 times, most recently from e903b71 to 99d53c2 Compare June 9, 2023 23:06
Comment on lines -906 to +933
-    if tracker.name == "tensorboard":
-        np_images = np.stack([np.asarray(img) for img in images])
-        tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-    if tracker.name == "wandb":
-        tracker.log(
-            {
-                "test": [
-                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                    for i, image in enumerate(images)
-                ]
-            }
-        )
+    if len(images) != 0:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "test": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                        for i, image in enumerate(images)
+                    ]
+                }
+            )
@williamberman williamberman (Contributor, Author) commented Jun 9, 2023

Added a length check for when num_validation_images is zero; without it, np.stack raises an error when given an empty list.
I had to set num_validation_images to zero because the dummy pipeline throws an error during inference. It would be ideal to fix the dummy pipeline along with the training script, but this is an ok workaround.

},
)

def test_text_to_image_checkpointing_checkpoints_total_limit(self):
williamberman (Contributor, Author) commented:

Every training script needs two tests:

One: that the marginal creation of a checkpoint that would place us over the limit deletes the earliest checkpoint
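A hedged, standalone sketch of what test one verifies. The real tests run the actual training scripts end to end; `simulate_training` below is a hypothetical helper that only mimics the per-step pruning logic added in this PR, so the expected surviving set matches the `{"checkpoint-4", "checkpoint-6"}` assertion in the test file:

```python
import os
import shutil
import tempfile

def simulate_training(output_dir, total_steps, checkpointing_steps, total_limit):
    """Mimic the checkpointing loop: before each save, prune old checkpoints
    so that at most `total_limit - 1` remain, then save the new one."""
    for global_step in range(1, total_steps + 1):
        if global_step % checkpointing_steps == 0:
            ckpts = sorted(
                (d for d in os.listdir(output_dir) if d.startswith("checkpoint")),
                key=lambda x: int(x.split("-")[1]),
            )
            if total_limit is not None and len(ckpts) >= total_limit:
                for name in ckpts[: len(ckpts) - total_limit + 1]:
                    shutil.rmtree(os.path.join(output_dir, name))
            os.makedirs(os.path.join(output_dir, f"checkpoint-{global_step}"))

out = tempfile.mkdtemp()
simulate_training(out, total_steps=6, checkpointing_steps=2, total_limit=2)
# only the two newest checkpoints survive
assert set(os.listdir(out)) == {"checkpoint-4", "checkpoint-6"}
```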

{"checkpoint-4", "checkpoint-6"},
)

def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
williamberman (Contributor, Author) commented:

Two: that restarting training with a lower checkpoints_total_limit deletes checkpoints, starting from the oldest, until we're down to the new limit.
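A hedged sketch of the scenario test two covers, again simulating the pruning step directly instead of running the real script: a previous run left four checkpoints, and the next save under a lower limit removes multiple old ones at once (directory names are illustrative):

```python
import os
import shutil
import tempfile

out = tempfile.mkdtemp()
# pretend a previous run (with a higher limit) saved four checkpoints
for step in (2, 4, 6, 8):
    os.makedirs(os.path.join(out, f"checkpoint-{step}"))

# next save after restarting with checkpoints_total_limit=2:
# prune down to limit - 1 before saving the new checkpoint
total_limit = 2
ckpts = sorted(
    (d for d in os.listdir(out) if d.startswith("checkpoint")),
    key=lambda x: int(x.split("-")[1]),
)
if len(ckpts) >= total_limit:
    for name in ckpts[: len(ckpts) - total_limit + 1]:
        shutil.rmtree(os.path.join(out, name))
os.makedirs(os.path.join(out, "checkpoint-10"))

# three old checkpoints were removed in one pass
assert set(os.listdir(out)) == {"checkpoint-8", "checkpoint-10"}
```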

Comment on lines 1068 to +1077
if accelerator.is_main_process:
    if global_step % args.checkpointing_steps == 0:
        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
        if args.checkpoints_total_limit is not None:
            checkpoints = os.listdir(args.output_dir)
            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
            if len(checkpoints) >= args.checkpoints_total_limit:
                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                removing_checkpoints = checkpoints[0:num_to_remove]

                logger.info(
                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                )
                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                for removing_checkpoint in removing_checkpoints:
                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                    shutil.rmtree(removing_checkpoint)

williamberman (Contributor, Author) commented:

Just using this snippet in each training script.
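The snippet above can be sketched as a small standalone function to make the off-by-one reasoning explicit. This is a hypothetical refactor for illustration only (`prune_checkpoints` is not a name used in the scripts; the real code inlines the logic and uses `args` and `logger`):

```python
import os
import shutil
import tempfile

def prune_checkpoints(output_dir: str, checkpoints_total_limit: int) -> list:
    """Delete the oldest "checkpoint-<step>" dirs so at most
    `checkpoints_total_limit - 1` remain, leaving room for the checkpoint
    about to be saved. Returns the names of the removed checkpoints."""
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    # sort numerically by global step, not lexicographically
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

    if len(checkpoints) < checkpoints_total_limit:
        return []

    num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
    removing = checkpoints[:num_to_remove]
    for name in removing:
        shutil.rmtree(os.path.join(output_dir, name))
    return removing

# quick demonstration in a temporary directory
out = tempfile.mkdtemp()
for step in (2, 4, 6):
    os.makedirs(os.path.join(out, f"checkpoint-{step}"))
removed = prune_checkpoints(out, checkpoints_total_limit=2)
```

Note the numeric sort key: a plain string sort would order "checkpoint-10" before "checkpoint-2".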

@patrickvonplaten patrickvonplaten (Contributor) left a comment

Looks good to me!

Would be great if we could try to remove this new argument: https://github.com/huggingface/diffusers/pull/3681/files#r1225441084

@pcuenca pcuenca (Member) left a comment

Looks great! I agree with Patrick that it'd be awesome to remove that new argument if we can.

@williamberman williamberman (Contributor, Author)

Removed the unneeded ControlNet argument!

@williamberman williamberman merged commit d49e2dd into huggingface:main Jun 15, 2023
bghira pushed a commit to bghira/SimpleTuner that referenced this pull request Jun 18, 2023
…sers#3681)

- add an lr_end parameter for setting that value

- fix the use of lr_power
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…uggingface#3681)

* manual check for checkpoints_total_limit instead of using accelerate

* remove controlnet_conditioning_embedding_out_channels