|
16 | 16 | from accelerate import Accelerator
|
17 | 17 | from accelerate.logging import get_logger
|
18 | 18 | from accelerate.utils import set_seed
|
19 |
| -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel |
| 19 | +from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel |
20 | 20 | from diffusers.optimization import get_scheduler
|
| 21 | +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker |
21 | 22 | from huggingface_hub import HfFolder, Repository, whoami
|
22 | 23 |
|
23 | 24 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
24 | 25 | from packaging import version
|
25 | 26 | from PIL import Image
|
26 | 27 | from torchvision import transforms
|
27 | 28 | from tqdm.auto import tqdm
|
28 |
| -from transformers import CLIPTextModel, CLIPTokenizer |
| 29 | +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
29 | 30 |
|
30 | 31 |
|
31 | 32 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
@@ -65,6 +66,12 @@ def parse_args():
|
65 | 66 | default=500,
|
66 | 67 | help="Save learned_embeds.bin every X updates steps.",
|
67 | 68 | )
|
| 69 | + parser.add_argument( |
| 70 | + "--only_save_embeds", |
| 71 | + action="store_true", |
| 72 | + default=False, |
| 73 | + help="Save only the embeddings for the new concept.", |
| 74 | + ) |
68 | 75 | parser.add_argument(
|
69 | 76 | "--pretrained_model_name_or_path",
|
70 | 77 | type=str,
|
@@ -596,16 +603,23 @@ def main():
|
596 | 603 |
|
597 | 604 | # Create the pipeline using the trained modules and save it.
|
598 | 605 | if accelerator.is_main_process:
|
599 |
| - pipeline = StableDiffusionPipeline.from_pretrained( |
600 |
| - args.pretrained_model_name_or_path, |
601 |
| - text_encoder=accelerator.unwrap_model(text_encoder), |
602 |
| - tokenizer=tokenizer, |
603 |
| - vae=vae, |
604 |
| - unet=unet, |
605 |
| - revision=args.revision, |
606 |
| - ) |
607 |
| - pipeline.save_pretrained(args.output_dir) |
608 |
| - # Also save the newly trained embeddings |
| 606 | + if args.push_to_hub and args.only_save_embeds: |
| 607 | + logger.warn("Enabling full model saving because --push_to_hub=True was specified.") |
| 608 | + save_full_model = True |
| 609 | + else: |
| 610 | + save_full_model = not args.only_save_embeds |
| 611 | + if save_full_model: |
| 612 | + pipeline = StableDiffusionPipeline( |
| 613 | + text_encoder=accelerator.unwrap_model(text_encoder), |
| 614 | + vae=vae, |
| 615 | + unet=unet, |
| 616 | + tokenizer=tokenizer, |
| 617 | + scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"), |
| 618 | + safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), |
| 619 | + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), |
| 620 | + ) |
| 621 | + pipeline.save_pretrained(args.output_dir) |
| 622 | + # Save the newly trained embeddings |
609 | 623 | save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
610 | 624 | save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
611 | 625 |
|
|
0 commit comments