Skip to content

Commit ed43f58

Browse files
authored
Update train_controlnet.py
1 parent 1c2d6b5 commit ed43f58

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

examples/controlnet/train_controlnet.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,14 @@ def parse_args(input_args=None):
493493
default="conditioning_image",
494494
help="The column of the dataset containing the controlnet conditioning image.",
495495
)
496+
497+
parser.add_argument(
498+
"--conditioning_image_alt_column",
499+
type=str,
500+
default="conditioning_image_alt",
501+
help="The column of the dataset containing the alternative controlnet conditioning image.",
502+
)
503+
496504
parser.add_argument(
497505
"--caption_column",
498506
type=str,
@@ -658,6 +666,20 @@ def make_train_dataset(args, tokenizer, accelerator):
658666
f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
659667
)
660668

669+
if args.conditioning_image_alt_column is None:
670+
if len(column_names) > 3:
671+
conditioning_image_alt_column = column_names[3]
672+
logger.info(f"conditioning image alt column defaulting to {conditioning_image_alt_column}")
673+
else:
674+
conditioning_image_alt_column = None
675+
logger.info("No conditioning image alt column found in dataset")
676+
else:
677+
conditioning_image_alt_column = args.conditioning_image_alt_column
678+
if conditioning_image_alt_column not in column_names:
679+
raise ValueError(
680+
f"`--conditioning_image_alt_column` value '{args.conditioning_image_alt_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
681+
)
682+
661683
def tokenize_captions(examples, is_train=True):
662684
captions = []
663685
for caption in examples[caption_column]:
@@ -701,8 +723,12 @@ def preprocess_train(examples):
701723
conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
702724
conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
703725

726+
conditioning_images_alt = [image.convert("RGB") for image in examples[conditioning_image_alt_column]]
727+
conditioning_images_alt = [conditioning_image_transforms(image) for image in conditioning_images_alt]
728+
704729
examples["pixel_values"] = images
705730
examples["conditioning_pixel_values"] = conditioning_images
731+
examples["conditioning_pixel_values_alt"] = conditioning_images_alt
706732
examples["input_ids"] = tokenize_captions(examples)
707733

708734
return examples
@@ -723,11 +749,15 @@ def collate_fn(examples):
723749
conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
724750
conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
725751

752+
conditioning_pixel_values_alt = torch.stack([example["conditioning_pixel_values_alt"] for example in examples])
753+
conditioning_pixel_values_alt = conditioning_pixel_values_alt.to(memory_format=torch.contiguous_format).float()
754+
726755
input_ids = torch.stack([example["input_ids"] for example in examples])
727756

728757
return {
729758
"pixel_values": pixel_values,
730759
"conditioning_pixel_values": conditioning_pixel_values,
760+
"conditioning_pixel_values_alt": conditioning_pixel_values_alt,
731761
"input_ids": input_ids,
732762
}
733763

@@ -798,12 +828,16 @@ def main(args):
798828

799829
# Load scheduler and models
800830
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
831+
801832
text_encoder = text_encoder_cls.from_pretrained(
802833
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
803834
)
804-
vae = AutoencoderKL.from_pretrained(
805-
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
806-
)
835+
# vae = AutoencoderKL.from_pretrained(
836+
# args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
837+
# )
838+
839+
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse')
840+
807841
unet = UNet2DConditionModel.from_pretrained(
808842
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
809843
)
@@ -1055,11 +1089,14 @@ def load_model_hook(models, input_dir):
10551089

10561090
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
10571091

1092+
controlnet_image_alt = batch["conditioning_pixel_values_alt"].to(dtype=weight_dtype)
1093+
10581094
down_block_res_samples, mid_block_res_sample = controlnet(
10591095
noisy_latents,
10601096
timesteps,
10611097
encoder_hidden_states=encoder_hidden_states,
10621098
controlnet_cond=controlnet_image,
1099+
# controlnet_cond_alt=controlnet_image_alt,
10631100
return_dict=False,
10641101
)
10651102

0 commit comments

Comments (0)