Skip to content

Commit 0856137

Browse files
[textual inversion] Allow validation images (huggingface#2077)
* [textual inversion] Allow validation images.
* Change key to `validation`
* Specify format instead of transposing. As discussed with @sayakpaul.
* Style

Co-authored-by: isamu-isozaki <[email protected]>
1 parent 946d1cb commit 0856137

File tree

1 file changed

+75
-2
lines changed

1 file changed

+75
-2
lines changed

examples/textual_inversion/textual_inversion.py

Lines changed: 75 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,16 @@
3434
from accelerate import Accelerator
3535
from accelerate.logging import get_logger
3636
from accelerate.utils import set_seed
37-
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
37+
from diffusers import (
38+
AutoencoderKL,
39+
DDPMScheduler,
40+
DiffusionPipeline,
41+
DPMSolverMultistepScheduler,
42+
StableDiffusionPipeline,
43+
UNet2DConditionModel,
44+
)
3845
from diffusers.optimization import get_scheduler
39-
from diffusers.utils import check_min_version
46+
from diffusers.utils import check_min_version, is_wandb_available
4047
from diffusers.utils.import_utils import is_xformers_available
4148
from huggingface_hub import HfFolder, Repository, create_repo, whoami
4249

@@ -250,6 +257,28 @@ def parse_args():
250257
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
251258
),
252259
)
260+
parser.add_argument(
261+
"--validation_prompt",
262+
type=str,
263+
default=None,
264+
help="A prompt that is used during validation to verify that the model is learning.",
265+
)
266+
parser.add_argument(
267+
"--num_validation_images",
268+
type=int,
269+
default=4,
270+
help="Number of images that should be generated during validation with `validation_prompt`.",
271+
)
272+
parser.add_argument(
273+
"--validation_epochs",
274+
type=int,
275+
default=50,
276+
help=(
277+
"Run validation every X epochs. Validation consists of running the prompt"
278+
" `args.validation_prompt` multiple times: `args.num_validation_images`"
279+
" and logging the images."
280+
),
281+
)
253282
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
254283
parser.add_argument(
255284
"--checkpointing_steps",
@@ -444,6 +473,11 @@ def main():
444473
logging_dir=logging_dir,
445474
)
446475

476+
if args.report_to == "wandb":
477+
if not is_wandb_available():
478+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
479+
import wandb
480+
447481
# Make one log on every process with the configuration for debugging.
448482
logging.basicConfig(
449483
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -740,6 +774,45 @@ def main():
740774
if global_step >= args.max_train_steps:
741775
break
742776

777+
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
778+
logger.info(
779+
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
780+
f" {args.validation_prompt}."
781+
)
782+
# create pipeline (note: unet and vae are loaded again in float32)
783+
pipeline = DiffusionPipeline.from_pretrained(
784+
args.pretrained_model_name_or_path,
785+
text_encoder=accelerator.unwrap_model(text_encoder),
786+
revision=args.revision,
787+
)
788+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
789+
pipeline = pipeline.to(accelerator.device)
790+
pipeline.set_progress_bar_config(disable=True)
791+
792+
# run inference
793+
generator = (
794+
None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
795+
)
796+
prompt = args.num_validation_images * [args.validation_prompt]
797+
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
798+
799+
for tracker in accelerator.trackers:
800+
if tracker.name == "tensorboard":
801+
np_images = np.stack([np.asarray(img) for img in images])
802+
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
803+
if tracker.name == "wandb":
804+
tracker.log(
805+
{
806+
"validation": [
807+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
808+
for i, image in enumerate(images)
809+
]
810+
}
811+
)
812+
813+
del pipeline
814+
torch.cuda.empty_cache()
815+
743816
# Create the pipeline using the trained modules and save it.
744817
accelerator.wait_for_everyone()
745818
if accelerator.is_main_process:

0 commit comments

Comments
 (0)