Skip to content

Commit d1bcbf3

Browse files
authored
[textual_inversion] Add an option for only saving the embeddings (huggingface#781)
[textual_inversion] Add an option to only save embeddings. Add a command-line option --only_save_embeds to the example script that skips saving the full model. Only the learned embeddings are then saved; they can be added to the original model at runtime in a way similar to how they are created in the training script. Saving the full model is forced when --push_to_hub is used. (Implements huggingface#759)
1 parent df7cd5f commit d1bcbf3

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

examples/textual_inversion/textual_inversion.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,17 @@
1616
from accelerate import Accelerator
1717
from accelerate.logging import get_logger
1818
from accelerate.utils import set_seed
19-
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
19+
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
2020
from diffusers.optimization import get_scheduler
21+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
2122
from huggingface_hub import HfFolder, Repository, whoami
2223

2324
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
2425
from packaging import version
2526
from PIL import Image
2627
from torchvision import transforms
2728
from tqdm.auto import tqdm
28-
from transformers import CLIPTextModel, CLIPTokenizer
29+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
2930

3031

3132
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
@@ -65,6 +66,12 @@ def parse_args():
6566
default=500,
6667
help="Save learned_embeds.bin every X updates steps.",
6768
)
69+
parser.add_argument(
70+
"--only_save_embeds",
71+
action="store_true",
72+
default=False,
73+
help="Save only the embeddings for the new concept.",
74+
)
6875
parser.add_argument(
6976
"--pretrained_model_name_or_path",
7077
type=str,
@@ -596,16 +603,23 @@ def main():
596603

597604
# Create the pipeline using the trained modules and save it.
598605
if accelerator.is_main_process:
599-
pipeline = StableDiffusionPipeline.from_pretrained(
600-
args.pretrained_model_name_or_path,
601-
text_encoder=accelerator.unwrap_model(text_encoder),
602-
tokenizer=tokenizer,
603-
vae=vae,
604-
unet=unet,
605-
revision=args.revision,
606-
)
607-
pipeline.save_pretrained(args.output_dir)
608-
# Also save the newly trained embeddings
606+
if args.push_to_hub and args.only_save_embeds:
607+
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
608+
save_full_model = True
609+
else:
610+
save_full_model = not args.only_save_embeds
611+
if save_full_model:
612+
pipeline = StableDiffusionPipeline(
613+
text_encoder=accelerator.unwrap_model(text_encoder),
614+
vae=vae,
615+
unet=unet,
616+
tokenizer=tokenizer,
617+
scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
618+
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
619+
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
620+
)
621+
pipeline.save_pretrained(args.output_dir)
622+
# Save the newly trained embeddings
609623
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
610624
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
611625

0 commit comments

Comments
 (0)