2727import torch
2828import torch .utils .checkpoint
2929import transformers
30- from datasets import load_dataset
30+ from datasets import load_dataset , load_from_disk
3131from flax import jax_utils
3232from flax .core .frozen_dict import unfreeze
3333from flax .training import train_state
3434from flax .training .common_utils import shard
3535from huggingface_hub import create_repo , upload_folder
36- from PIL import Image
36+ from PIL import Image , PngImagePlugin
3737from torch .utils .data import IterableDataset
3838from torchvision import transforms
3939from tqdm .auto import tqdm
4949from diffusers .utils import check_min_version , is_wandb_available
5050
5151
52+ # To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
53+ # see more https://github.com/python-pillow/Pillow/issues/5610
54+ LARGE_ENOUGH_NUMBER = 100
55+ PngImagePlugin .MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024 ** 2 )
56+
5257if is_wandb_available ():
5358 import wandb
5459
@@ -246,6 +251,12 @@ def parse_args():
246251 default = None ,
247252 help = "Total number of training steps to perform." ,
248253 )
254+ parser .add_argument (
255+ "--checkpointing_steps" ,
256+ type = int ,
257+ default = 5000 ,
258+ help = ("Save a checkpoint of the training state every X updates." ),
259+ )
249260 parser .add_argument (
250261 "--learning_rate" ,
251262 type = float ,
@@ -344,9 +355,17 @@ def parse_args():
344355 type = str ,
345356 default = None ,
346357 help = (
347- "A folder containing the training data. Folder contents must follow the structure described in"
348- " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
349- " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
358+ "A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
359+ "Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
360+ "If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
361+ ),
362+ )
363+ parser .add_argument (
364+ "--load_from_disk" ,
365+ action = "store_true" ,
366+ help = (
367+ "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
368+ "See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
350369 ),
351370 )
352371 parser .add_argument (
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
478497 )
479498 else :
480499 if args .train_data_dir is not None :
481- dataset = load_dataset (
482- args .train_data_dir ,
483- cache_dir = args .cache_dir ,
484- )
500+ if args .load_from_disk :
501+ dataset = load_from_disk (
502+ args .train_data_dir ,
503+ )
504+ else :
505+ dataset = load_dataset (
506+ args .train_data_dir ,
507+ cache_dir = args .cache_dir ,
508+ )
485509 # See more about loading custom images at
486510 # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
487511
@@ -545,6 +569,7 @@ def tokenize_captions(examples, is_train=True):
545569 image_transforms = transforms .Compose (
546570 [
547571 transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR ),
572+ transforms .CenterCrop (args .resolution ),
548573 transforms .ToTensor (),
549574 transforms .Normalize ([0.5 ], [0.5 ]),
550575 ]
@@ -553,6 +578,7 @@ def tokenize_captions(examples, is_train=True):
553578 conditioning_image_transforms = transforms .Compose (
554579 [
555580 transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR ),
581+ transforms .CenterCrop (args .resolution ),
556582 transforms .ToTensor (),
557583 ]
558584 )
@@ -981,6 +1007,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
9811007 "train/loss" : jax_utils .unreplicate (train_metric )["loss" ],
9821008 }
9831009 )
1010+ if global_step % args .checkpointing_steps == 0 and jax .process_index () == 0 :
1011+ controlnet .save_pretrained (
1012+ f"{ args .output_dir } /{ global_step } " ,
1013+ params = get_params_to_save (state .params ),
1014+ )
9841015
9851016 train_metric = jax_utils .unreplicate (train_metric )
9861017 train_step_progress_bar .close ()
0 commit comments