Skip to content

Commit 6d2e19f

Browse files
[Examples] Allow downloading variant model files (huggingface#5531)
* add variant
* add variant
* Apply suggestions from code review
* reformat
* fix: textual_inversion.py
* fix: variant in model_info

---------

Co-authored-by: sayakpaul <[email protected]>
1 parent 2a7f43a commit 6d2e19f

16 files changed: 266 additions (+266) and 77 deletions (-77).

examples/controlnet/train_controlnet.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
8686
controlnet=controlnet,
8787
safety_checker=None,
8888
revision=args.revision,
89+
variant=args.variant,
8990
torch_dtype=weight_dtype,
9091
)
9192
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -249,10 +250,13 @@ def parse_args(input_args=None):
249250
type=str,
250251
default=None,
251252
required=False,
252-
help=(
253-
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
254-
" float32 precision."
255-
),
253+
help="Revision of pretrained model identifier from huggingface.co/models.",
254+
)
255+
parser.add_argument(
256+
"--variant",
257+
type=str,
258+
default=None,
259+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
256260
)
257261
parser.add_argument(
258262
"--tokenizer_name",
@@ -767,11 +771,13 @@ def main(args):
767771
# Load scheduler and models
768772
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
769773
text_encoder = text_encoder_cls.from_pretrained(
770-
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
774+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
775+
)
776+
vae = AutoencoderKL.from_pretrained(
777+
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
771778
)
772-
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
773779
unet = UNet2DConditionModel.from_pretrained(
774-
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
780+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
775781
)
776782

777783
if args.controlnet_model_name_or_path:

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
7474
unet=unet,
7575
controlnet=controlnet,
7676
revision=args.revision,
77+
variant=args.variant,
7778
torch_dtype=weight_dtype,
7879
)
7980
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -243,15 +244,18 @@ def parse_args(input_args=None):
243244
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
244245
" If not specified controlnet weights are initialized from unet.",
245246
)
247+
parser.add_argument(
248+
"--variant",
249+
type=str,
250+
default=None,
251+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
252+
)
246253
parser.add_argument(
247254
"--revision",
248255
type=str,
249256
default=None,
250257
required=False,
251-
help=(
252-
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
253-
" float32 precision."
254-
),
258+
help="Revision of pretrained model identifier from huggingface.co/models.",
255259
)
256260
parser.add_argument(
257261
"--tokenizer_name",
@@ -793,10 +797,16 @@ def main(args):
793797

794798
# Load the tokenizers
795799
tokenizer_one = AutoTokenizer.from_pretrained(
796-
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
800+
args.pretrained_model_name_or_path,
801+
subfolder="tokenizer",
802+
revision=args.revision,
803+
use_fast=False,
797804
)
798805
tokenizer_two = AutoTokenizer.from_pretrained(
799-
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
806+
args.pretrained_model_name_or_path,
807+
subfolder="tokenizer_2",
808+
revision=args.revision,
809+
use_fast=False,
800810
)
801811

802812
# import correct text encoder classes
@@ -810,10 +820,10 @@ def main(args):
810820
# Load scheduler and models
811821
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
812822
text_encoder_one = text_encoder_cls_one.from_pretrained(
813-
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
823+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
814824
)
815825
text_encoder_two = text_encoder_cls_two.from_pretrained(
816-
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
826+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
817827
)
818828
vae_path = (
819829
args.pretrained_model_name_or_path
@@ -824,9 +834,10 @@ def main(args):
824834
vae_path,
825835
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
826836
revision=args.revision,
837+
variant=args.variant,
827838
)
828839
unet = UNet2DConditionModel.from_pretrained(
829-
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
840+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
830841
)
831842

832843
if args.controlnet_model_name_or_path:

examples/custom_diffusion/train_custom_diffusion.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,12 @@ def parse_args(input_args=None):
332332
required=False,
333333
help="Revision of pretrained model identifier from huggingface.co/models.",
334334
)
335+
parser.add_argument(
336+
"--variant",
337+
type=str,
338+
default=None,
339+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
340+
)
335341
parser.add_argument(
336342
"--tokenizer_name",
337343
type=str,
@@ -740,6 +746,7 @@ def main(args):
740746
torch_dtype=torch_dtype,
741747
safety_checker=None,
742748
revision=args.revision,
749+
variant=args.variant,
743750
)
744751
pipeline.set_progress_bar_config(disable=True)
745752

@@ -801,11 +808,13 @@ def main(args):
801808
# Load scheduler and models
802809
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
803810
text_encoder = text_encoder_cls.from_pretrained(
804-
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
811+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
812+
)
813+
vae = AutoencoderKL.from_pretrained(
814+
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
805815
)
806-
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
807816
unet = UNet2DConditionModel.from_pretrained(
808-
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
817+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
809818
)
810819

811820
# Adding a modifier token which is optimized ####
@@ -1229,6 +1238,7 @@ def main(args):
12291238
text_encoder=accelerator.unwrap_model(text_encoder),
12301239
tokenizer=tokenizer,
12311240
revision=args.revision,
1241+
variant=args.variant,
12321242
torch_dtype=weight_dtype,
12331243
)
12341244
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -1278,7 +1288,7 @@ def main(args):
12781288
# Final inference
12791289
# Load previous pipeline
12801290
pipeline = DiffusionPipeline.from_pretrained(
1281-
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
1291+
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
12821292
)
12831293
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
12841294
pipeline = pipeline.to(accelerator.device)

examples/dreambooth/train_dreambooth.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def log_validation(
139139
text_encoder=text_encoder,
140140
unet=accelerator.unwrap_model(unet),
141141
revision=args.revision,
142+
variant=args.variant,
142143
torch_dtype=weight_dtype,
143144
**pipeline_args,
144145
)
@@ -239,10 +240,13 @@ def parse_args(input_args=None):
239240
type=str,
240241
default=None,
241242
required=False,
242-
help=(
243-
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
244-
" float32 precision."
245-
),
243+
help="Revision of pretrained model identifier from huggingface.co/models.",
244+
)
245+
parser.add_argument(
246+
"--variant",
247+
type=str,
248+
default=None,
249+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
246250
)
247251
parser.add_argument(
248252
"--tokenizer_name",
@@ -859,6 +863,7 @@ def main(args):
859863
torch_dtype=torch_dtype,
860864
safety_checker=None,
861865
revision=args.revision,
866+
variant=args.variant,
862867
)
863868
pipeline.set_progress_bar_config(disable=True)
864869

@@ -912,18 +917,18 @@ def main(args):
912917
# Load scheduler and models
913918
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
914919
text_encoder = text_encoder_cls.from_pretrained(
915-
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
920+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
916921
)
917922

918923
if model_has_vae(args):
919924
vae = AutoencoderKL.from_pretrained(
920-
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
925+
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
921926
)
922927
else:
923928
vae = None
924929

925930
unet = UNet2DConditionModel.from_pretrained(
926-
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
931+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
927932
)
928933

929934
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1379,6 +1384,7 @@ def compute_text_embeddings(prompt):
13791384
args.pretrained_model_name_or_path,
13801385
unet=accelerator.unwrap_model(unet),
13811386
revision=args.revision,
1387+
variant=args.variant,
13821388
**pipeline_args,
13831389
)
13841390

examples/dreambooth/train_dreambooth_flax.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,15 +460,21 @@ def collate_fn(examples):
460460

461461
# Load models and create wrapper for stable diffusion
462462
text_encoder = FlaxCLIPTextModel.from_pretrained(
463-
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
463+
args.pretrained_model_name_or_path,
464+
subfolder="text_encoder",
465+
dtype=weight_dtype,
466+
revision=args.revision,
464467
)
465468
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
466469
vae_arg,
467470
dtype=weight_dtype,
468471
**vae_kwargs,
469472
)
470473
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
471-
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
474+
args.pretrained_model_name_or_path,
475+
subfolder="unet",
476+
dtype=weight_dtype,
477+
revision=args.revision,
472478
)
473479

474480
# Optimization

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,12 @@ def parse_args(input_args=None):
183183
required=False,
184184
help="Revision of pretrained model identifier from huggingface.co/models.",
185185
)
186+
parser.add_argument(
187+
"--variant",
188+
type=str,
189+
default=None,
190+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
191+
)
186192
parser.add_argument(
187193
"--tokenizer_name",
188194
type=str,
@@ -750,6 +756,7 @@ def main(args):
750756
torch_dtype=torch_dtype,
751757
safety_checker=None,
752758
revision=args.revision,
759+
variant=args.variant,
753760
)
754761
pipeline.set_progress_bar_config(disable=True)
755762

@@ -803,19 +810,19 @@ def main(args):
803810
# Load scheduler and models
804811
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
805812
text_encoder = text_encoder_cls.from_pretrained(
806-
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
813+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
807814
)
808815
try:
809816
vae = AutoencoderKL.from_pretrained(
810-
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
817+
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
811818
)
812819
except OSError:
813820
# IF does not have a VAE so let's just set it to None
814821
# We don't have to error out here
815822
vae = None
816823

817824
unet = UNet2DConditionModel.from_pretrained(
818-
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
825+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
819826
)
820827

821828
# We only train the additional adapter LoRA layers
@@ -1310,6 +1317,7 @@ def compute_text_embeddings(prompt):
13101317
unet=accelerator.unwrap_model(unet),
13111318
text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
13121319
revision=args.revision,
1320+
variant=args.variant,
13131321
torch_dtype=weight_dtype,
13141322
)
13151323

@@ -1395,7 +1403,7 @@ def compute_text_embeddings(prompt):
13951403
# Final inference
13961404
# Load previous pipeline
13971405
pipeline = DiffusionPipeline.from_pretrained(
1398-
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
1406+
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
13991407
)
14001408

14011409
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it

0 commit comments

Comments
 (0)