Skip to content

Commit e289998

Browse files
authored
fix mask discrepancies in train_dreambooth_inpaint (huggingface#1529)
The mask and the instance image were being cropped in different ways when --center_crop was not set, causing the model to learn to ignore the mask in some cases. This PR fixes that and generates more consistent results.
1 parent 634be6e commit e289998

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

examples/dreambooth/train_dreambooth_inpaint.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -295,10 +295,15 @@ def __init__(
295295
else:
296296
self.class_data_root = None
297297

298-
self.image_transforms = transforms.Compose(
298+
self.image_transforms_resize_and_crop = transforms.Compose(
299299
[
300300
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
301301
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
302+
]
303+
)
304+
305+
self.image_transforms = transforms.Compose(
306+
[
302307
transforms.ToTensor(),
303308
transforms.Normalize([0.5], [0.5]),
304309
]
@@ -312,6 +317,7 @@ def __getitem__(self, index):
312317
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
313318
if not instance_image.mode == "RGB":
314319
instance_image = instance_image.convert("RGB")
320+
instance_image = self.image_transforms_resize_and_crop(instance_image)
315321

316322
example["PIL_images"] = instance_image
317323
example["instance_images"] = self.image_transforms(instance_image)
@@ -327,6 +333,7 @@ def __getitem__(self, index):
327333
class_image = Image.open(self.class_images_path[index % self.num_class_images])
328334
if not class_image.mode == "RGB":
329335
class_image = class_image.convert("RGB")
336+
class_image = self.image_transforms_resize_and_crop(class_image)
330337
example["class_images"] = self.image_transforms(class_image)
331338
example["class_PIL_images"] = class_image
332339
example["class_prompt_ids"] = self.tokenizer(
@@ -513,12 +520,6 @@ def main():
513520
)
514521

515522
def collate_fn(examples):
516-
image_transforms = transforms.Compose(
517-
[
518-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
519-
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
520-
]
521-
)
522523
input_ids = [example["instance_prompt_ids"] for example in examples]
523524
pixel_values = [example["instance_images"] for example in examples]
524525

@@ -535,9 +536,6 @@ def collate_fn(examples):
535536
pil_image = example["PIL_images"]
536537
# generate a random mask
537538
mask = random_mask(pil_image.size, 1, False)
538-
# apply transforms
539-
mask = image_transforms(mask)
540-
pil_image = image_transforms(pil_image)
541539
# prepare mask and masked image
542540
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
543541

@@ -548,9 +546,6 @@ def collate_fn(examples):
548546
for pil_image in pior_pil:
549547
# generate a random mask
550548
mask = random_mask(pil_image.size, 1, False)
551-
# apply transforms
552-
mask = image_transforms(mask)
553-
pil_image = image_transforms(pil_image)
554549
# prepare mask and masked image
555550
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
556551

0 commit comments

Comments
 (0)