34 | 34 | from accelerate import Accelerator |
35 | 35 | from accelerate.logging import get_logger |
36 | 36 | from accelerate.utils import set_seed |
37 | | -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel |
| 37 | +from diffusers import ( |
| 38 | + AutoencoderKL, |
| 39 | + DDPMScheduler, |
| 40 | + DiffusionPipeline, |
| 41 | + DPMSolverMultistepScheduler, |
| 42 | + StableDiffusionPipeline, |
| 43 | + UNet2DConditionModel, |
| 44 | +) |
38 | 45 | from diffusers.optimization import get_scheduler |
39 | | -from diffusers.utils import check_min_version |
| 46 | +from diffusers.utils import check_min_version, is_wandb_available |
40 | 47 | from diffusers.utils.import_utils import is_xformers_available |
41 | 48 | from huggingface_hub import HfFolder, Repository, create_repo, whoami |
42 | 49 |
@@ -250,6 +257,28 @@ def parse_args(): |
250 | 257 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' |
251 | 258 | ), |
252 | 259 | ) |
| 260 | + parser.add_argument( |
| 261 | + "--validation_prompt", |
| 262 | + type=str, |
| 263 | + default=None, |
| 264 | + help="A prompt that is used during validation to verify that the model is learning.", |
| 265 | + ) |
| 266 | + parser.add_argument( |
| 267 | + "--num_validation_images", |
| 268 | + type=int, |
| 269 | + default=4, |
| 270 | + help="Number of images that should be generated during validation with `validation_prompt`.", |
| 271 | + ) |
| 272 | + parser.add_argument( |
| 273 | + "--validation_epochs", |
| 274 | + type=int, |
| 275 | + default=50, |
| 276 | + help=( |
| 277 | + "Run validation every X epochs. Validation consists of running the prompt" |
| 278 | +            " `args.validation_prompt` a total of `args.num_validation_images` times"
| 279 | +            " and logging the resulting images."
| 280 | + ), |
| 281 | + ) |
253 | 282 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") |
254 | 283 | parser.add_argument( |
255 | 284 | "--checkpointing_steps", |
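
A quick, hedged sketch of the cadence implied by the new `--validation_epochs` flag (the values below are placeholders, not defaults from this diff): because the training loop's epoch counter typically starts at 0, the `epoch % args.validation_epochs == 0` check added further down also fires after the very first epoch.

```python
# Hedged sketch: which epoch indices trigger validation for a given --validation_epochs.
# num_train_epochs and validation_epochs are placeholder values for illustration.
num_train_epochs = 12
validation_epochs = 5

validation_runs = [epoch for epoch in range(num_train_epochs) if epoch % validation_epochs == 0]
print(validation_runs)  # [0, 5, 10] -> includes epoch index 0, i.e. right after the first epoch
```
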
@@ -444,6 +473,11 @@ def main(): |
444 | 473 | logging_dir=logging_dir, |
445 | 474 | ) |
446 | 475 |
| 476 | + if args.report_to == "wandb": |
| 477 | + if not is_wandb_available(): |
| 478 | + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") |
| 479 | + import wandb |
| 480 | + |
447 | 481 | # Make one log on every process with the configuration for debugging. |
448 | 482 | logging.basicConfig( |
449 | 483 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
@@ -740,6 +774,45 @@ def main(): |
740 | 774 | if global_step >= args.max_train_steps: |
741 | 775 | break |
742 | 776 |
| 777 | + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: |
| 778 | + logger.info( |
| 779 | + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" |
| 780 | + f" {args.validation_prompt}." |
| 781 | + ) |
| 782 | + # create pipeline (note: unet and vae are loaded again in float32) |
| 783 | + pipeline = DiffusionPipeline.from_pretrained( |
| 784 | + args.pretrained_model_name_or_path, |
| 785 | + text_encoder=accelerator.unwrap_model(text_encoder), |
| 786 | + revision=args.revision, |
| 787 | + ) |
| 788 | + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) |
| 789 | + pipeline = pipeline.to(accelerator.device) |
| 790 | + pipeline.set_progress_bar_config(disable=True) |
| 791 | + |
| 792 | + # run inference |
| 793 | + generator = ( |
| 794 | + None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) |
| 795 | + ) |
| 796 | + prompt = args.num_validation_images * [args.validation_prompt] |
| 797 | + images = pipeline(prompt, num_inference_steps=25, generator=generator).images |
| 798 | + |
| 799 | + for tracker in accelerator.trackers: |
| 800 | + if tracker.name == "tensorboard": |
| 801 | + np_images = np.stack([np.asarray(img) for img in images]) |
| 802 | + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") |
| 803 | + if tracker.name == "wandb": |
| 804 | + tracker.log( |
| 805 | + { |
| 806 | + "validation": [ |
| 807 | + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") |
| 808 | + for i, image in enumerate(images) |
| 809 | + ] |
| 810 | + } |
| 811 | + ) |
| 812 | + |
| 813 | + del pipeline |
| 814 | + torch.cuda.empty_cache() |
| 815 | + |
743 | 816 |     # Create the pipeline using the trained modules and save it.
744 | 817 | accelerator.wait_for_everyone() |
745 | 818 | if accelerator.is_main_process: |
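
For reviewers who want to eyeball the new validation path outside of a training run, here is a minimal, hedged sketch that mirrors the pipeline construction added above (the model id, prompt, and device handling are placeholder assumptions, not values from this PR):

```python
# Hedged sketch (not part of this PR): reproduce the validation rendering path standalone.
import numpy as np
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler


def render_validation_images(model_id, prompt, num_images=4, seed=None):
    # Load the full pipeline and swap in the faster multistep scheduler,
    # as the diff does inside the training loop.
    pipeline = DiffusionPipeline.from_pretrained(model_id)
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.set_progress_bar_config(disable=True)

    # A fixed seed keeps validation images comparable across epochs/checkpoints.
    generator = None if seed is None else torch.Generator(device=pipeline.device).manual_seed(seed)
    images = pipeline([prompt] * num_images, num_inference_steps=25, generator=generator).images

    # NHWC uint8 batch, the layout tensorboard's add_images(..., dataformats="NHWC") expects.
    return np.stack([np.asarray(img) for img in images])


if __name__ == "__main__":
    grid = render_validation_images(
        "runwayml/stable-diffusion-v1-5",  # placeholder model id
        "a photo of a cat wearing a hat",  # placeholder validation prompt
        num_images=2,
        seed=42,
    )
    print(grid.shape)  # e.g. (2, 512, 512, 3)
```

Swapping in `DPMSolverMultistepScheduler` and using only 25 inference steps keeps these intermediate renders cheap relative to the pipeline's default scheduler, which is presumably why the diff applies the swap only for validation.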