Skip to content

Commit ac4c695

Browse files
authored
[Flax examples] Load text encoder from subfolder (huggingface#1147)
load text encoder from subfolder
1 parent 0173323 commit ac4c695

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

examples/dreambooth/train_dreambooth_flax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,9 @@ def collate_fn(examples):
452452
weight_dtype = jnp.bfloat16
453453

454454
# Load models and create wrapper for stable diffusion
455-
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype)
455+
text_encoder = FlaxCLIPTextModel.from_pretrained(
456+
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
457+
)
456458
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
457459
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
458460
)

examples/text_to_image/train_text_to_image_flax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,9 @@ def collate_fn(examples):
379379

380380
# Load models and create wrapper for stable diffusion
381381
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
382-
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype)
382+
text_encoder = FlaxCLIPTextModel.from_pretrained(
383+
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
384+
)
383385
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
384386
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
385387
)

examples/textual_inversion/textual_inversion_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def main():
391391
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
392392

393393
# Load models and create wrapper for stable diffusion
394-
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
394+
text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
395395
vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
396396
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
397397

0 commit comments

Comments (0)