@@ -451,19 +451,18 @@ def main():
     # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
     # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
     # initialized to zero.
-    if accelerator.is_main_process:
-        logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
-        in_channels = 8
-        out_channels = unet.conv_in.out_channels
-        unet.register_to_config(in_channels=in_channels)
-
-        with torch.no_grad():
-            new_conv_in = nn.Conv2d(
-                in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
-            )
-            new_conv_in.weight.zero_()
-            new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
-            unet.conv_in = new_conv_in
+    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+    in_channels = 8
+    out_channels = unet.conv_in.out_channels
+    unet.register_to_config(in_channels=in_channels)
+
+    with torch.no_grad():
+        new_conv_in = nn.Conv2d(
+            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+        )
+        new_conv_in.weight.zero_()
+        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        unet.conv_in = new_conv_in
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
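
Note on the hunk above: dropping the `accelerator.is_main_process` guard means every rank performs the conv_in surgery, so all processes hold an identically shaped UNet before it is wrapped for distributed training. As a standalone illustration of the zero-init expansion itself, here is a minimal sketch with hypothetical names (not from this commit); unlike the training script, it also copies the bias so the equivalence check holds exactly:

import torch
import torch.nn as nn

# Hypothetical stand-in for unet.conv_in (4 latent channels in, 320 out).
old_conv = nn.Conv2d(4, 320, kernel_size=3, padding=1)

# Expand to 8 input channels: copy the pretrained weights into the first 4,
# zero the extra 4 so the conditioning image is a no-op at initialization.
new_conv = nn.Conv2d(8, old_conv.out_channels, old_conv.kernel_size, old_conv.stride, old_conv.padding)
with torch.no_grad():
    new_conv.weight.zero_()
    new_conv.weight[:, :4, :, :].copy_(old_conv.weight)
    new_conv.bias.copy_(old_conv.bias)  # the training script leaves the bias at its default init

# With zeros in the extra channels, the expanded conv matches the original.
x = torch.randn(1, 4, 64, 64)
x8 = torch.cat([x, torch.zeros_like(x)], dim=1)
assert torch.allclose(old_conv(x), new_conv(x8), atol=1e-6)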
@@ -892,9 +891,12 @@ def collate_fn(examples):
                     # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                     ema_unet.store(unet.parameters())
                     ema_unet.copy_to(unet.parameters())
+                # The models need unwrapping for compatibility in distributed training mode.
                 pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=unet,
+                    unet=accelerator.unwrap_model(unet),
+                    text_encoder=accelerator.unwrap_model(text_encoder),
+                    vae=accelerator.unwrap_model(vae),
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
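
For context on this hunk: after `accelerator.prepare()` in distributed mode, the models may be wrapped (e.g., in `torch.nn.parallel.DistributedDataParallel`), while `from_pretrained` expects the bare modules; `accelerator.unwrap_model` recovers them. A minimal standalone sketch of the behavior (not from this script):

import torch
from accelerate import Accelerator

# prepare() may wrap the model for distributed training;
# unwrap_model() always returns the underlying nn.Module.
accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))
bare = accelerator.unwrap_model(model)
assert isinstance(bare, torch.nn.Linear)  # holds whether or not model was wrapped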
@@ -904,7 +906,9 @@ def collate_fn(examples):
                 # run inference
                 original_image = download_image(args.val_image_url)
                 edited_images = []
-                with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
+                with torch.autocast(
+                    str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
+                ):
                     for _ in range(args.num_validation_images):
                         edited_images.append(
                             pipeline(
@@ -959,7 +963,7 @@ def collate_fn(examples):
             if args.validation_prompt is not None:
                 edited_images = []
                 pipeline = pipeline.to(accelerator.device)
-                with torch.autocast(str(accelerator.device).replace(":0", "")):
+                with torch.autocast(str(accelerator.device).replace(":0", "")):
                     for _ in range(args.num_validation_images):
                         edited_images.append(
                             pipeline(
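
Both autocast hunks apply the same fix: `torch.autocast` expects a device type such as "cuda" or "cpu", while `str(accelerator.device)` can yield "cuda:0", which autocast rejects. Stripping ":0" restores a valid device type. A small standalone sketch of the distinction; note that `accelerator.device.type` would be an index-agnostic alternative, since `.replace(":0", "")` only handles device index 0:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

str(device)                    # "cuda:0" on GPU -- not a valid autocast device type
str(device).replace(":0", "")  # "cuda" -- the form the patch passes
device.type                    # "cuda" -- equivalent, and works for any device index

with torch.autocast(device.type, enabled=device.type == "cuda"):
    y = torch.ones(2, 2, device=device) @ torch.ones(2, 2, device=device)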