Skip to content

Commit 8f2c7b4

Browse files
authored
[advanced sdxl lora script] - fix huggingface#6967 bug when using prior preservation loss (huggingface#6968)

* fix bug in micro-conditioning of class images
* fix bug in micro-conditioning of class images
* style
1 parent 2e387da commit 8f2c7b4

File tree

1 file changed

+34
-12
lines changed

1 file changed

+34
-12
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 34 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -939,6 +939,32 @@ def __init__(
939939
self.class_data_root = Path(class_data_root)
940940
self.class_data_root.mkdir(parents=True, exist_ok=True)
941941
self.class_images_path = list(self.class_data_root.iterdir())
942+
943+
self.original_sizes_class_imgs = []
944+
self.crop_top_lefts_class_imgs = []
945+
self.pixel_values_class_imgs = []
946+
self.class_images = [Image.open(path) for path in self.class_images_path]
947+
for image in self.class_images:
948+
image = exif_transpose(image)
949+
if not image.mode == "RGB":
950+
image = image.convert("RGB")
951+
self.original_sizes_class_imgs.append((image.height, image.width))
952+
image = train_resize(image)
953+
if args.random_flip and random.random() < 0.5:
954+
# flip
955+
image = train_flip(image)
956+
if args.center_crop:
957+
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
958+
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
959+
image = train_crop(image)
960+
else:
961+
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
962+
image = crop(image, y1, x1, h, w)
963+
crop_top_left = (y1, x1)
964+
self.crop_top_lefts_class_imgs.append(crop_top_left)
965+
image = train_transforms(image)
966+
self.pixel_values_class_imgs.append(image)
967+
942968
if class_num is not None:
943969
self.num_class_images = min(len(self.class_images_path), class_num)
944970
else:
@@ -961,12 +987,9 @@ def __len__(self):
961987

962988
def __getitem__(self, index):
963989
example = {}
964-
instance_image = self.pixel_values[index % self.num_instance_images]
965-
original_size = self.original_sizes[index % self.num_instance_images]
966-
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
967-
example["instance_images"] = instance_image
968-
example["original_size"] = original_size
969-
example["crop_top_left"] = crop_top_left
990+
example["instance_images"] = self.pixel_values[index % self.num_instance_images]
991+
example["original_size"] = self.original_sizes[index % self.num_instance_images]
992+
example["crop_top_left"] = self.crop_top_lefts[index % self.num_instance_images]
970993

971994
if self.custom_instance_prompts:
972995
caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -983,13 +1006,10 @@ def __getitem__(self, index):
9831006
example["instance_prompt"] = self.instance_prompt
9841007

9851008
if self.class_data_root:
986-
class_image = Image.open(self.class_images_path[index % self.num_class_images])
987-
class_image = exif_transpose(class_image)
988-
989-
if not class_image.mode == "RGB":
990-
class_image = class_image.convert("RGB")
991-
example["class_images"] = self.image_transforms(class_image)
9921009
example["class_prompt"] = self.class_prompt
1010+
example["class_images"] = self.pixel_values_class_imgs[index % self.num_class_images]
1011+
example["class_original_size"] = self.original_sizes_class_imgs[index % self.num_class_images]
1012+
example["class_crop_top_left"] = self.crop_top_lefts_class_imgs[index % self.num_class_images]
9931013

9941014
return example
9951015

@@ -1005,6 +1025,8 @@ def collate_fn(examples, with_prior_preservation=False):
10051025
if with_prior_preservation:
10061026
pixel_values += [example["class_images"] for example in examples]
10071027
prompts += [example["class_prompt"] for example in examples]
1028+
original_sizes += [example["class_original_size"] for example in examples]
1029+
crop_top_lefts += [example["class_crop_top_left"] for example in examples]
10081030

10091031
pixel_values = torch.stack(pixel_values)
10101032
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

0 commit comments

Comments (0)