@@ -304,9 +304,10 @@ def __getitem__(self, index):
304304 example ["instance_images" ] = self .image_transforms (instance_image )
305305 example ["instance_prompt_ids" ] = self .tokenizer (
306306 self .instance_prompt ,
307- padding = "do_not_pad" ,
308307 truncation = True ,
308+ padding = "max_length" ,
309309 max_length = self .tokenizer .model_max_length ,
310+ return_tensors = "pt" ,
310311 ).input_ids
311312
312313 if self .class_data_root :
@@ -316,14 +317,37 @@ def __getitem__(self, index):
316317 example ["class_images" ] = self .image_transforms (class_image )
317318 example ["class_prompt_ids" ] = self .tokenizer (
318319 self .class_prompt ,
319- padding = "do_not_pad" ,
320320 truncation = True ,
321+ padding = "max_length" ,
321322 max_length = self .tokenizer .model_max_length ,
323+ return_tensors = "pt" ,
322324 ).input_ids
323325
324326 return example
325327
326328
def collate_fn(examples, with_prior_preservation=False):
    """Collate dataset examples into a single training batch.

    The instance prompts/images are always included; when prior
    preservation is enabled, the class prompts/images of the same
    examples are appended after them, so both halves of the batch share
    one forward pass instead of two.
    """
    prompt_keys = ["instance_prompt_ids"]
    image_keys = ["instance_images"]
    if with_prior_preservation:
        prompt_keys.append("class_prompt_ids")
        image_keys.append("class_images")

    # Instance entries first, then (optionally) class entries — order matters
    # downstream when the batch is split back into the two halves.
    id_tensors = [example[key] for key in prompt_keys for example in examples]
    image_tensors = [example[key] for key in image_keys for example in examples]

    stacked_images = torch.stack(image_tensors)
    # Contiguous float tensor regardless of the transforms' output layout/dtype.
    stacked_images = stacked_images.to(memory_format=torch.contiguous_format).float()

    return {
        "input_ids": torch.cat(id_tensors, dim=0),
        "pixel_values": stacked_images,
    }
349+
350+
327351class PromptDataset (Dataset ):
328352 "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
329353
@@ -514,34 +538,12 @@ def main(args):
514538 center_crop = args .center_crop ,
515539 )
516540
517- def collate_fn (examples ):
518- input_ids = [example ["instance_prompt_ids" ] for example in examples ]
519- pixel_values = [example ["instance_images" ] for example in examples ]
520-
521- # Concat class and instance examples for prior preservation.
522- # We do this to avoid doing two forward passes.
523- if args .with_prior_preservation :
524- input_ids += [example ["class_prompt_ids" ] for example in examples ]
525- pixel_values += [example ["class_images" ] for example in examples ]
526-
527- pixel_values = torch .stack (pixel_values )
528- pixel_values = pixel_values .to (memory_format = torch .contiguous_format ).float ()
529-
530- input_ids = tokenizer .pad (
531- {"input_ids" : input_ids },
532- padding = "max_length" ,
533- max_length = tokenizer .model_max_length ,
534- return_tensors = "pt" ,
535- ).input_ids
536-
537- batch = {
538- "input_ids" : input_ids ,
539- "pixel_values" : pixel_values ,
540- }
541- return batch
542-
543541 train_dataloader = torch .utils .data .DataLoader (
544- train_dataset , batch_size = args .train_batch_size , shuffle = True , collate_fn = collate_fn , num_workers = 1
542+ train_dataset ,
543+ batch_size = args .train_batch_size ,
544+ shuffle = True ,
545+ collate_fn = lambda examples : collate_fn (examples , args .with_prior_preservation ),
546+ num_workers = 1 ,
545547 )
546548
547549 # Scheduler and math around the number of training steps.
0 commit comments