Use CC12M for LCM WDS training example #5908

Merged
pcuenca merged 6 commits into main from lcm-wds-example
Dec 6, 2023

Conversation

pcuenca
Member

@pcuenca pcuenca commented Nov 23, 2023

Had to make some adjustments to make it work with CC12M, as the metadata field names in the JSON are different. In particular, if no pwatermark field existed, all images were rejected.

cc @stevhliu @sayakpaul @patil-suraj

Fixes #5868, #5770.
Possibly fixes #5743 for other datasets.
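
For readers following along, here is a minimal sketch of the kind of metadata filter described above. It is not the code from this PR: the function name `keep_sample`, the `width`/`height` key names, and the default thresholds are all assumptions made for illustration. The point it shows is that a missing `pwatermark` field is treated as "no watermark evidence" instead of causing every sample to be rejected.

```python
import json


def keep_sample(sample, min_size=512, max_pwatermark=0.5,
                width_key="width", height_key="height"):
    """Return True if a sample's JSON metadata passes the size and watermark checks.

    All names and defaults here are illustrative assumptions, not the script's own.
    """
    try:
        meta = json.loads(sample["json"])
    except (KeyError, TypeError, ValueError):
        # No metadata, or metadata that cannot be parsed: drop the sample.
        return False
    if not isinstance(meta, dict):
        return False

    # Missing or null size fields count as 0, so the size check fails for them.
    width = meta.get(width_key) or 0
    height = meta.get(height_key) or 0
    big_enough = width >= min_size and height >= min_size

    # A missing or null watermark score is treated as 0.0 (no evidence of a
    # watermark) rather than 1.0, so datasets without this field are not
    # rejected wholesale.
    pwatermark = meta.get("pwatermark")
    if pwatermark is None:
        pwatermark = 0.0

    return big_enough and pwatermark <= max_pwatermark
```

A predicate of this shape could plausibly be handed to a webdataset filtering stage such as `wds.select(...)`. The key names would need to match whatever the dataset's JSON actually contains.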

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Member

@sayakpaul sayakpaul left a comment

Thanks for the changes, they look good to me. Would it make sense to reflect the peft-specific changes from #5778 to make it run more efficiently?

@pcuenca
Member Author

pcuenca commented Dec 4, 2023

> Would it make sense to reflect the peft specific changes from #5778 to make it run more efficiently?

Makes sense. I'd rather tackle that in a separate PR, if possible, so this one can be merged and helps the community with the issues linked in the first comment.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Great idea! @patil-suraj @dg845 can you please check here?

@@ -1097,7 +1097,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     for epoch in range(first_epoch, args.num_train_epochs):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                image, text, _, _ = batch
+                image, text = batch
Contributor

Just to confirm, is this change because the SD distillation script currently has a bug where it assumes that the Text2ImageDataset.train_dataloader batch consists of (image, text, orig_size, crop_coords) like the SDXL dataloader instead of just (image, text):

wds.map(filter_keys({"image", "text"})),
wds.map(transform),
wds.to_tuple("image", "text"),

and is otherwise unrelated to CC12M support?

Contributor

Looks like 711e468 mentions this. Feel free to close :).

Member Author

Yes, exactly, it's actually a bug and not directly related to CC12M :)
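
As a small illustration of the bug discussed in this thread, the sketch below uses plain Python stand-ins instead of real tensors. The helper name `unpack_sd_batch` and the placeholder batch values are assumptions made for the example. It shows why four-way unpacking fails when the dataloader yields only `(image, text)`.

```python
def unpack_sd_batch(batch):
    """Unpack a two-item (image, text) batch; raise a clear error otherwise."""
    if len(batch) != 2:
        raise ValueError(f"expected (image, text), got {len(batch)} items")
    image, text = batch
    return image, text


# Stand-in values; a real batch would hold tensors and a list of captions.
sd_batch = ("image_tensor", ["a caption"])

# The old four-way unpacking fails on a two-item batch.
try:
    image, text, _, _ = sd_batch
    old_unpack_error = None
except ValueError as err:
    old_unpack_error = str(err)

# The two-way unpacking matches what wds.to_tuple("image", "text") yields.
image, text = unpack_sd_batch(sd_batch)
```

An SDXL-style batch with `(image, text, orig_size, crop_coords)` would still need the four-way form, which is why the two dataloaders cannot share the same unpacking line.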

export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
--resolution=512 \
--learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
Contributor

Do the hyperparameters in the examples work well with the CC12M dataset or does it potentially make sense to revisit them?

Member Author

Not sure, to be honest. I haven't experimented much with it, considering that CC12M is very small by current standards.

Perhaps it would be easier to include a disclaimer stating that this dataset is used for illustrative purposes and users are encouraged to bring their own.

@dg845
Contributor

dg845 commented Dec 4, 2023

I think it could be valuable to make sure WebdatasetFilter and Text2ImageDataset are consistent across the scripts (and maybe separating them into SD and SDXL versions, e.g. SDText2ImageDataset/SDXLText2ImageDataset), perhaps through the # Copied from mechanism. Not sure if it makes sense to do that in this PR though.

@pcuenca
Member Author

pcuenca commented Dec 5, 2023

> I think it could be valuable to make sure WebdatasetFilter and Text2ImageDataset are consistent across the scripts (and maybe separating them into SD and SDXL versions, e.g. SDText2ImageDataset/SDXLText2ImageDataset), perhaps through the # Copied from mechanism. Not sure if it makes sense to do that in this PR though.

I think that's a good idea! I'd rather work on that separately, if possible, so we can close those issues from the community.

I'll add a disclaimer about the illustrative nature of the dataset. Edit: done, @dg845 let me know if that'd be enough.

@pcuenca pcuenca requested a review from dg845 December 5, 2023 10:11
Contributor

@dg845 dg845 left a comment

Looks good to me :).

@pcuenca
Member Author

pcuenca commented Dec 6, 2023

Thanks for the great reviews!

@pcuenca pcuenca merged commit ab6672f into main Dec 6, 2023
@pcuenca pcuenca deleted the lcm-wds-example branch December 6, 2023 09:35
@sayakpaul sayakpaul mentioned this pull request Dec 11, 2023
donhardman pushed a commit to donhardman/diffusers that referenced this pull request Dec 18, 2023
* Fix SD scripts - there are only 2 items per batch

* Adjustments to make the SDXL scripts work with other datasets

* Use public webdataset dataset for examples

* make style

* Minor tweaks to the readmes.

* Stress that the database is illustrative.
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
Development

Successfully merging this pull request may close these issues.

Dataset access of LCM training
[Latent Consistency Distillation] training stuck at 0%
5 participants