
Commit 9e11029

[dreambooth] make collate_fn global (huggingface#1547)
make collate_fn global
1 parent c228331 commit 9e11029

File tree

1 file changed (+31, -29)

examples/dreambooth/train_dreambooth.py

+31, -29
@@ -304,9 +304,10 @@ def __getitem__(self, index):
         example["instance_images"] = self.image_transforms(instance_image)
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
-            padding="do_not_pad",
             truncation=True,
+            padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
         ).input_ids
 
         if self.class_data_root:
@@ -316,14 +317,37 @@ def __getitem__(self, index):
             example["class_images"] = self.image_transforms(class_image)
             example["class_prompt_ids"] = self.tokenizer(
                 self.class_prompt,
-                padding="do_not_pad",
                 truncation=True,
+                padding="max_length",
                 max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
             ).input_ids
 
         return example
 
 
+def collate_fn(examples, with_prior_preservation=False):
+    input_ids = [example["instance_prompt_ids"] for example in examples]
+    pixel_values = [example["instance_images"] for example in examples]
+
+    # Concat class and instance examples for prior preservation.
+    # We do this to avoid doing two forward passes.
+    if with_prior_preservation:
+        input_ids += [example["class_prompt_ids"] for example in examples]
+        pixel_values += [example["class_images"] for example in examples]
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+    input_ids = torch.cat(input_ids, dim=0)
+
+    batch = {
+        "input_ids": input_ids,
+        "pixel_values": pixel_values,
+    }
+    return batch
+
+
 class PromptDataset(Dataset):
     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
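
Note: the padding change in this hunk is what lets the new module-level collate_fn (above) concatenate prompt ids with torch.cat instead of calling tokenizer.pad. With truncation, padding="max_length", and return_tensors="pt", every prompt encodes to a fixed-shape (1, model_max_length) tensor. A minimal sketch of that behaviour, not part of the commit; the hub id and prompts below are illustrative, assuming the CLIP tokenizer used by Stable Diffusion checkpoints:

import torch
from transformers import CLIPTokenizer

# Illustrative tokenizer: the training script loads the tokenizer bundled with
# the chosen Stable Diffusion checkpoint rather than this hub id.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def encode(prompt):
    # Mirrors the dataset's new tokenizer call: fixed length, returned as a tensor.
    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

ids_a = encode("a photo of sks dog")  # illustrative instance prompt
ids_b = encode("a photo of a dog")    # illustrative class prompt
print(ids_a.shape)                             # torch.Size([1, 77])
print(torch.cat([ids_a, ids_b], dim=0).shape)  # torch.Size([2, 77]); no tokenizer.pad needed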

@@ -514,34 +538,12 @@ def main(args):
         center_crop=args.center_crop,
     )
 
-    def collate_fn(examples):
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
-        if args.with_prior_preservation:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
-        input_ids = tokenizer.pad(
-            {"input_ids": input_ids},
-            padding="max_length",
-            max_length=tokenizer.model_max_length,
-            return_tensors="pt",
-        ).input_ids
-
-        batch = {
-            "input_ids": input_ids,
-            "pixel_values": pixel_values,
-        }
-        return batch
-
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+        train_dataset,
+        batch_size=args.train_batch_size,
+        shuffle=True,
+        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+        num_workers=1,
     )
 
     # Scheduler and math around the number of training steps.
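
For reference, a minimal usage sketch of the now-global collate_fn, not from the commit. The example dicts are hypothetical stand-ins for what DreamBoothDataset yields, and the import assumes the sketch is run from examples/dreambooth with the script's dependencies installed:

import torch

# Assumption: running from examples/dreambooth makes the script importable;
# importing it pulls in the training dependencies (diffusers, transformers, ...).
from train_dreambooth import collate_fn

# Hypothetical stand-ins for DreamBoothDataset examples: each carries a
# (1, 77) prompt-id tensor and a (3, 512, 512) image tensor.
examples = [
    {
        "instance_prompt_ids": torch.zeros(1, 77, dtype=torch.long),
        "instance_images": torch.zeros(3, 512, 512),
    }
    for _ in range(4)
]

batch = collate_fn(examples, with_prior_preservation=False)
print(batch["input_ids"].shape)     # torch.Size([4, 77])
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512])

functools.partial(collate_fn, with_prior_preservation=args.with_prior_preservation) would be an equivalent alternative to the lambda passed to the DataLoader above; because it wraps a module-level function, it also stays picklable if num_workers were ever raised under a spawn start method, which a lambda or a function nested inside main would not be.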
