@@ -295,10 +295,15 @@ def __init__(
295
295
else :
296
296
self .class_data_root = None
297
297
298
- self .image_transforms = transforms .Compose (
298
+ self .image_transforms_resize_and_crop = transforms .Compose (
299
299
[
300
300
transforms .Resize (size , interpolation = transforms .InterpolationMode .BILINEAR ),
301
301
transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size ),
302
+ ]
303
+ )
304
+
305
+ self .image_transforms = transforms .Compose (
306
+ [
302
307
transforms .ToTensor (),
303
308
transforms .Normalize ([0.5 ], [0.5 ]),
304
309
]
@@ -312,6 +317,7 @@ def __getitem__(self, index):
312
317
instance_image = Image .open (self .instance_images_path [index % self .num_instance_images ])
313
318
if not instance_image .mode == "RGB" :
314
319
instance_image = instance_image .convert ("RGB" )
320
+ instance_image = self .image_transforms_resize_and_crop (instance_image )
315
321
316
322
example ["PIL_images" ] = instance_image
317
323
example ["instance_images" ] = self .image_transforms (instance_image )
@@ -327,6 +333,7 @@ def __getitem__(self, index):
327
333
class_image = Image .open (self .class_images_path [index % self .num_class_images ])
328
334
if not class_image .mode == "RGB" :
329
335
class_image = class_image .convert ("RGB" )
336
+ class_image = self .image_transforms_resize_and_crop (class_image )
330
337
example ["class_images" ] = self .image_transforms (class_image )
331
338
example ["class_PIL_images" ] = class_image
332
339
example ["class_prompt_ids" ] = self .tokenizer (
@@ -513,12 +520,6 @@ def main():
513
520
)
514
521
515
522
def collate_fn (examples ):
516
- image_transforms = transforms .Compose (
517
- [
518
- transforms .Resize (args .resolution , interpolation = transforms .InterpolationMode .BILINEAR ),
519
- transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution ),
520
- ]
521
- )
522
523
input_ids = [example ["instance_prompt_ids" ] for example in examples ]
523
524
pixel_values = [example ["instance_images" ] for example in examples ]
524
525
@@ -535,9 +536,6 @@ def collate_fn(examples):
535
536
pil_image = example ["PIL_images" ]
536
537
# generate a random mask
537
538
mask = random_mask (pil_image .size , 1 , False )
538
- # apply transforms
539
- mask = image_transforms (mask )
540
- pil_image = image_transforms (pil_image )
541
539
# prepare mask and masked image
542
540
mask , masked_image = prepare_mask_and_masked_image (pil_image , mask )
543
541
@@ -548,9 +546,6 @@ def collate_fn(examples):
548
546
for pil_image in pior_pil :
549
547
# generate a random mask
550
548
mask = random_mask (pil_image .size , 1 , False )
551
- # apply transforms
552
- mask = image_transforms (mask )
553
- pil_image = image_transforms (pil_image )
554
549
# prepare mask and masked image
555
550
mask , masked_image = prepare_mask_and_masked_image (pil_image , mask )
556
551
0 commit comments