Use CC12M for LCM WDS training example #5908
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks for the changes, they look good to me. Would it make sense to reflect the peft-specific changes from #5778 to make it run more efficiently?
Makes sense. I'd rather tackle that in a separate PR, if possible, so this one can be merged and help the community with the issues linked in the first comment.
Great idea! @patil-suraj @dg845 can you please check here?
@@ -1097,7 +1097,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     for epoch in range(first_epoch, args.num_train_epochs):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                image, text, _, _ = batch
+                image, text = batch
Just to confirm, is this change because the SD distillation script currently has a bug where it assumes that the Text2ImageDataset.train_dataloader batch consists of (image, text, orig_size, crop_coords), like the SDXL dataloader, instead of just (image, text):
diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py
Lines 176 to 178 in 110ac7f
    wds.map(filter_keys({"image", "text"})),
    wds.map(transform),
    wds.to_tuple("image", "text"),
and is otherwise unrelated to CC12M support?
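To make the mismatch concrete, here is a minimal sketch of the unpacking problem. The batch below is a hypothetical stand-in built from plain lists (not real tensors or the actual dataloader), and the helper names are invented for illustration; it only assumes, as quoted above, that the SD pipeline ends in wds.to_tuple("image", "text") and so yields two items per batch.

```python
# Hypothetical stand-in for one batch from the SD dataloader, which
# yields only (image, text) after wds.to_tuple("image", "text").
batch = (["img0", "img1"], ["a cat", "a dog"])


def unpack_sdxl_style(batch):
    # SDXL-style unpacking expects four items:
    # (image, text, orig_size, crop_coords).
    image, text, _, _ = batch
    return image, text


def unpack_sd_style(batch):
    # SD-style unpacking matches the two-item tuple.
    image, text = batch
    return image, text


# Four-way unpacking of a two-item tuple raises ValueError.
try:
    unpack_sdxl_style(batch)
    sdxl_error = None
except ValueError as exc:
    sdxl_error = str(exc)

# Two-way unpacking succeeds.
image, text = unpack_sd_style(batch)
```

With a two-item batch, the four-name unpacking fails before any training step runs, which is consistent with the one-line fix in the diff.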
Looks like 711e468 mentions this. Feel free to close :).
Yes, exactly, it's actually a bug and not directly related to CC12M :)
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
    --pretrained_teacher_model=$MODEL_NAME \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=512 \
    --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
Do the hyperparameters in the examples work well with the CC12M dataset, or does it make sense to revisit them?
Not sure, to be honest. I haven't experimented much with it, considering that CC12M is very small by current standards.
Perhaps it would be easier to include a disclaimer stating that this dataset is used for illustrative purposes and that users are encouraged to bring their own.
I think it could be valuable to make sure |
I think that's a good idea! I'd rather work on that separately, if possible, so we can close those issues from the community. I'll add a disclaimer about the illustrative nature of the dataset. Edit: done. @dg845, let me know if that is enough.
Looks good to me :).
Thanks for the great reviews!
* Fix SD scripts - there are only 2 items per batch
* Adjustments to make the SDXL scripts work with other datasets
* Use public webdataset dataset for examples
* make style
* Minor tweaks to the readmes.
* Stress that the database is illustrative.
Had to make some adjustments to make it work with CC12M, as the metadata field names in the JSON are different. In particular, if no pwatermark field existed, all images were rejected.
cc @stevhliu @sayakpaul @patil-suraj
Fixes #5868, #5770.
Possibly fixes #5743 for other datasets.
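As an illustration of the pwatermark issue described above, here is a minimal sketch of a metadata filter that tolerates a missing field. The function name keep_sample, the default threshold of 0.5, and the sample JSON strings are all assumptions made for this example; this is not the code from the PR, only one way such a check could be written.

```python
import json


def keep_sample(raw_json, max_pwatermark=0.5):
    """Return True if a sample should be kept (hypothetical helper).

    A missing "pwatermark" field is treated as "no watermark information",
    so the sample is kept instead of being rejected.
    """
    try:
        meta = json.loads(raw_json)
    except json.JSONDecodeError:
        # Unparseable metadata: drop the sample.
        return False
    if not isinstance(meta, dict):
        # Metadata that is not a JSON object: drop the sample.
        return False

    pwatermark = meta.get("pwatermark")
    if pwatermark is None:
        # Field absent (or null), as in datasets without this key.
        return True
    return pwatermark <= max_pwatermark


# Usage with made-up metadata strings.
without_field = keep_sample('{"caption": "a cat"}')
low_score = keep_sample('{"pwatermark": 0.1}')
high_score = keep_sample('{"pwatermark": 0.9}')
```

The design choice here is to treat absent metadata as permissive; a stricter pipeline might prefer to reject such samples explicitly, but doing so by default would discard every image in a dataset that lacks the field.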