@@ -74,6 +74,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
7474 unet = unet ,
7575 controlnet = controlnet ,
7676 revision = args .revision ,
77+ variant = args .variant ,
7778 torch_dtype = weight_dtype ,
7879 )
7980 pipeline .scheduler = UniPCMultistepScheduler .from_config (pipeline .scheduler .config )
@@ -243,15 +244,18 @@ def parse_args(input_args=None):
243244 help = "Path to pretrained controlnet model or model identifier from huggingface.co/models."
244245 " If not specified controlnet weights are initialized from unet." ,
245246 )
247+ parser .add_argument (
248+ "--variant" ,
249+ type = str ,
250+ default = None ,
251+ help = "Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16" ,
252+ )
246253 parser .add_argument (
247254 "--revision" ,
248255 type = str ,
249256 default = None ,
250257 required = False ,
251- help = (
252- "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
253- " float32 precision."
254- ),
258+ help = "Revision of pretrained model identifier from huggingface.co/models." ,
255259 )
256260 parser .add_argument (
257261 "--tokenizer_name" ,
@@ -793,10 +797,16 @@ def main(args):
793797
794798 # Load the tokenizers
795799 tokenizer_one = AutoTokenizer .from_pretrained (
796- args .pretrained_model_name_or_path , subfolder = "tokenizer" , revision = args .revision , use_fast = False
800+ args .pretrained_model_name_or_path ,
801+ subfolder = "tokenizer" ,
802+ revision = args .revision ,
803+ use_fast = False ,
797804 )
798805 tokenizer_two = AutoTokenizer .from_pretrained (
799- args .pretrained_model_name_or_path , subfolder = "tokenizer_2" , revision = args .revision , use_fast = False
806+ args .pretrained_model_name_or_path ,
807+ subfolder = "tokenizer_2" ,
808+ revision = args .revision ,
809+ use_fast = False ,
800810 )
801811
802812 # import correct text encoder classes
@@ -810,10 +820,10 @@ def main(args):
810820 # Load scheduler and models
811821 noise_scheduler = DDPMScheduler .from_pretrained (args .pretrained_model_name_or_path , subfolder = "scheduler" )
812822 text_encoder_one = text_encoder_cls_one .from_pretrained (
813- args .pretrained_model_name_or_path , subfolder = "text_encoder" , revision = args .revision
823+ args .pretrained_model_name_or_path , subfolder = "text_encoder" , revision = args .revision , variant = args . variant
814824 )
815825 text_encoder_two = text_encoder_cls_two .from_pretrained (
816- args .pretrained_model_name_or_path , subfolder = "text_encoder_2" , revision = args .revision
826+ args .pretrained_model_name_or_path , subfolder = "text_encoder_2" , revision = args .revision , variant = args . variant
817827 )
818828 vae_path = (
819829 args .pretrained_model_name_or_path
@@ -824,9 +834,10 @@ def main(args):
824834 vae_path ,
825835 subfolder = "vae" if args .pretrained_vae_model_name_or_path is None else None ,
826836 revision = args .revision ,
837+ variant = args .variant ,
827838 )
828839 unet = UNet2DConditionModel .from_pretrained (
829- args .pretrained_model_name_or_path , subfolder = "unet" , revision = args .revision
840+ args .pretrained_model_name_or_path , subfolder = "unet" , revision = args .revision , variant = args . variant
830841 )
831842
832843 if args .controlnet_model_name_or_path :
0 commit comments