@@ -493,6 +493,14 @@ def parse_args(input_args=None):
493493 default = "conditioning_image" ,
494494 help = "The column of the dataset containing the controlnet conditioning image." ,
495495 )
496+
497+ parser .add_argument (
498+ "--conditioning_image_alt_column" ,
499+ type = str ,
500+ default = "conditioning_image_alt" ,
501+ help = "The column of the dataset containing the alternative controlnet conditioning image." ,
502+ )
503+
496504 parser .add_argument (
497505 "--caption_column" ,
498506 type = str ,
@@ -658,6 +666,20 @@ def make_train_dataset(args, tokenizer, accelerator):
658666 f"`--conditioning_image_column` value '{ args .conditioning_image_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
659667 )
660668
669+ if args .conditioning_image_alt_column is None :
670+ if len (column_names ) > 3 :
671+ conditioning_image_alt_column = column_names [3 ]
672+ logger .info (f"conditioning image alt column defaulting to { conditioning_image_alt_column } " )
673+ else :
674+ conditioning_image_alt_column = None
675+ logger .info ("No conditioning image alt column found in dataset" )
676+ else :
677+ conditioning_image_alt_column = args .conditioning_image_alt_column
678+ if conditioning_image_alt_column not in column_names :
679+ raise ValueError (
680+ f"`--conditioning_image_alt_column` value '{ args .conditioning_image_alt_column } ' not found in dataset columns. Dataset columns are: { ', ' .join (column_names )} "
681+ )
682+
661683 def tokenize_captions (examples , is_train = True ):
662684 captions = []
663685 for caption in examples [caption_column ]:
@@ -701,8 +723,12 @@ def preprocess_train(examples):
701723 conditioning_images = [image .convert ("RGB" ) for image in examples [conditioning_image_column ]]
702724 conditioning_images = [conditioning_image_transforms (image ) for image in conditioning_images ]
703725
726+ conditioning_images_alt = [image .convert ("RGB" ) for image in examples [conditioning_image_alt_column ]]
727+ conditioning_images_alt = [conditioning_image_transforms (image ) for image in conditioning_images_alt ]
728+
704729 examples ["pixel_values" ] = images
705730 examples ["conditioning_pixel_values" ] = conditioning_images
731+ examples ["conditioning_pixel_values_alt" ] = conditioning_images_alt
706732 examples ["input_ids" ] = tokenize_captions (examples )
707733
708734 return examples
@@ -723,11 +749,15 @@ def collate_fn(examples):
723749 conditioning_pixel_values = torch .stack ([example ["conditioning_pixel_values" ] for example in examples ])
724750 conditioning_pixel_values = conditioning_pixel_values .to (memory_format = torch .contiguous_format ).float ()
725751
752+ conditioning_pixel_values_alt = torch .stack ([example ["conditioning_pixel_values_alt" ] for example in examples ])
753+ conditioning_pixel_values_alt = conditioning_pixel_values_alt .to (memory_format = torch .contiguous_format ).float ()
754+
726755 input_ids = torch .stack ([example ["input_ids" ] for example in examples ])
727756
728757 return {
729758 "pixel_values" : pixel_values ,
730759 "conditioning_pixel_values" : conditioning_pixel_values ,
760+ "conditioning_pixel_values_alt" : conditioning_pixel_values_alt ,
731761 "input_ids" : input_ids ,
732762 }
733763
@@ -798,12 +828,16 @@ def main(args):
798828
799829 # Load scheduler and models
800830 noise_scheduler = DDPMScheduler .from_pretrained (args .pretrained_model_name_or_path , subfolder = "scheduler" )
831+
801832 text_encoder = text_encoder_cls .from_pretrained (
802833 args .pretrained_model_name_or_path , subfolder = "text_encoder" , revision = args .revision , variant = args .variant
803834 )
804- vae = AutoencoderKL .from_pretrained (
805- args .pretrained_model_name_or_path , subfolder = "vae" , revision = args .revision , variant = args .variant
806- )
835+ # vae = AutoencoderKL.from_pretrained(
836+ # args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
837+ # )
838+
839+ vae = AutoencoderKL .from_pretrained ('stabilityai/sd-vae-ft-mse' )
840+
807841 unet = UNet2DConditionModel .from_pretrained (
808842 args .pretrained_model_name_or_path , subfolder = "unet" , revision = args .revision , variant = args .variant
809843 )
@@ -1055,11 +1089,14 @@ def load_model_hook(models, input_dir):
10551089
10561090 controlnet_image = batch ["conditioning_pixel_values" ].to (dtype = weight_dtype )
10571091
1092+ controlnet_image_alt = batch ["conditioning_pixel_values_alt" ].to (dtype = weight_dtype )
1093+
10581094 down_block_res_samples , mid_block_res_sample = controlnet (
10591095 noisy_latents ,
10601096 timesteps ,
10611097 encoder_hidden_states = encoder_hidden_states ,
10621098 controlnet_cond = controlnet_image ,
1099+ # controlnet_cond_alt=controlnet_image_alt,
10631100 return_dict = False ,
10641101 )
10651102
0 commit comments