@@ -939,6 +939,32 @@ def __init__(
939939 self .class_data_root = Path (class_data_root )
940940 self .class_data_root .mkdir (parents = True , exist_ok = True )
941941 self .class_images_path = list (self .class_data_root .iterdir ())
942+
943+ self .original_sizes_class_imgs = []
944+ self .crop_top_lefts_class_imgs = []
945+ self .pixel_values_class_imgs = []
946+ self .class_images = [Image .open (path ) for path in self .class_images_path ]
947+ for image in self .class_images :
948+ image = exif_transpose (image )
949+ if not image .mode == "RGB" :
950+ image = image .convert ("RGB" )
951+ self .original_sizes_class_imgs .append ((image .height , image .width ))
952+ image = train_resize (image )
953+ if args .random_flip and random .random () < 0.5 :
954+ # flip
955+ image = train_flip (image )
956+ if args .center_crop :
957+ y1 = max (0 , int (round ((image .height - args .resolution ) / 2.0 )))
958+ x1 = max (0 , int (round ((image .width - args .resolution ) / 2.0 )))
959+ image = train_crop (image )
960+ else :
961+ y1 , x1 , h , w = train_crop .get_params (image , (args .resolution , args .resolution ))
962+ image = crop (image , y1 , x1 , h , w )
963+ crop_top_left = (y1 , x1 )
964+ self .crop_top_lefts_class_imgs .append (crop_top_left )
965+ image = train_transforms (image )
966+ self .pixel_values_class_imgs .append (image )
967+
942968 if class_num is not None :
943969 self .num_class_images = min (len (self .class_images_path ), class_num )
944970 else :
@@ -961,12 +987,9 @@ def __len__(self):
961987
962988 def __getitem__ (self , index ):
963989 example = {}
964- instance_image = self .pixel_values [index % self .num_instance_images ]
965- original_size = self .original_sizes [index % self .num_instance_images ]
966- crop_top_left = self .crop_top_lefts [index % self .num_instance_images ]
967- example ["instance_images" ] = instance_image
968- example ["original_size" ] = original_size
969- example ["crop_top_left" ] = crop_top_left
990+ example ["instance_images" ] = self .pixel_values [index % self .num_instance_images ]
991+ example ["original_size" ] = self .original_sizes [index % self .num_instance_images ]
992+ example ["crop_top_left" ] = self .crop_top_lefts [index % self .num_instance_images ]
970993
971994 if self .custom_instance_prompts :
972995 caption = self .custom_instance_prompts [index % self .num_instance_images ]
@@ -983,13 +1006,10 @@ def __getitem__(self, index):
9831006 example ["instance_prompt" ] = self .instance_prompt
9841007
9851008 if self .class_data_root :
986- class_image = Image .open (self .class_images_path [index % self .num_class_images ])
987- class_image = exif_transpose (class_image )
988-
989- if not class_image .mode == "RGB" :
990- class_image = class_image .convert ("RGB" )
991- example ["class_images" ] = self .image_transforms (class_image )
9921009 example ["class_prompt" ] = self .class_prompt
1010+ example ["class_images" ] = self .pixel_values_class_imgs [index % self .num_class_images ]
1011+ example ["class_original_size" ] = self .original_sizes_class_imgs [index % self .num_class_images ]
1012+ example ["class_crop_top_left" ] = self .crop_top_lefts_class_imgs [index % self .num_class_images ]
9931013
9941014 return example
9951015
@@ -1005,6 +1025,8 @@ def collate_fn(examples, with_prior_preservation=False):
10051025 if with_prior_preservation :
10061026 pixel_values += [example ["class_images" ] for example in examples ]
10071027 prompts += [example ["class_prompt" ] for example in examples ]
1028+ original_sizes += [example ["class_original_size" ] for example in examples ]
1029+ crop_top_lefts += [example ["class_crop_top_left" ] for example in examples ]
10081030
10091031 pixel_values = torch .stack (pixel_values )
10101032 pixel_values = pixel_values .to (memory_format = torch .contiguous_format ).float ()
0 commit comments