@@ -304,9 +304,10 @@ def __getitem__(self, index):
         example["instance_images"] = self.image_transforms(instance_image)
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
-            padding="do_not_pad",
             truncation=True,
+            padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
         ).input_ids
 
         if self.class_data_root:
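Note: swapping `padding="do_not_pad"` for `padding="max_length"` (together with `return_tensors="pt"`) makes the tokenizer emit a fixed-size `(1, model_max_length)` PyTorch tensor for every prompt, which is what lets the new collate function concatenate ids directly instead of re-padding. A minimal sketch of the resulting shape, assuming a CLIP tokenizer from `transformers` (the checkpoint name and prompt below are illustrative, not taken from this commit):

from transformers import CLIPTokenizer

# Illustrative checkpoint; the training script loads its tokenizer from args.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

input_ids = tokenizer(
    "a photo of sks dog",
    truncation=True,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
).input_ids
print(input_ids.shape)  # torch.Size([1, 77]) -- fixed length, ready for torch.cat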
@@ -316,14 +317,37 @@ def __getitem__(self, index):
             example["class_images"] = self.image_transforms(class_image)
             example["class_prompt_ids"] = self.tokenizer(
                 self.class_prompt,
-                padding="do_not_pad",
                 truncation=True,
+                padding="max_length",
                 max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
             ).input_ids
 
         return example
 
 
+def collate_fn(examples, with_prior_preservation=False):
+    input_ids = [example["instance_prompt_ids"] for example in examples]
+    pixel_values = [example["instance_images"] for example in examples]
+
+    # Concat class and instance examples for prior preservation.
+    # We do this to avoid doing two forward passes.
+    if with_prior_preservation:
+        input_ids += [example["class_prompt_ids"] for example in examples]
+        pixel_values += [example["class_images"] for example in examples]
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+    input_ids = torch.cat(input_ids, dim=0)
+
+    batch = {
+        "input_ids": input_ids,
+        "pixel_values": pixel_values,
+    }
+    return batch
+
+
 class PromptDataset(Dataset):
     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
 
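Note: because every `*_prompt_ids` entry now arrives pre-padded to the same length, the new module-level `collate_fn` can batch ids with a plain `torch.cat` and no longer needs the tokenizer; when prior preservation is on, class examples are appended to the instance examples so a single forward pass covers both. A usage sketch, assuming the `collate_fn` defined above is in scope and using made-up shapes (77 is CLIP's `model_max_length`; 512x512 matches the script's default resolution, an assumption here):

import torch

# Two fake dataset items shaped like what __getitem__ now returns.
examples = [
    {
        "instance_prompt_ids": torch.ones(1, 77, dtype=torch.long),
        "instance_images": torch.randn(3, 512, 512),
    }
    for _ in range(2)
]

batch = collate_fn(examples, with_prior_preservation=False)
print(batch["input_ids"].shape)     # torch.Size([2, 77])
print(batch["pixel_values"].shape)  # torch.Size([2, 3, 512, 512])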
@@ -514,34 +538,12 @@ def main(args):
         center_crop=args.center_crop,
     )
 
-    def collate_fn(examples):
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
-        if args.with_prior_preservation:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
-        input_ids = tokenizer.pad(
-            {"input_ids": input_ids},
-            padding="max_length",
-            max_length=tokenizer.model_max_length,
-            return_tensors="pt",
-        ).input_ids
-
-        batch = {
-            "input_ids": input_ids,
-            "pixel_values": pixel_values,
-        }
-        return batch
-
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+        train_dataset,
+        batch_size=args.train_batch_size,
+        shuffle=True,
+        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+        num_workers=1,
     )
 
     # Scheduler and math around the number of training steps.
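Note: the `lambda` passed as `collate_fn` works here, but lambdas cannot be pickled, so DataLoader workers started with the "spawn" method (the default on Windows and macOS) would fail to serialize it. A `functools.partial` is a picklable alternative; this is a sketch of that option, not part of the commit:

from functools import partial

# Equivalent to the lambda above, but picklable under "spawn" worker start-up.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=partial(collate_fn, with_prior_preservation=args.with_prior_preservation),
    num_workers=1,
)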