
Commit 2f9a70a

[LoRA] Make sure validation works in multi GPU setup (huggingface#2172)
* [LoRA] Make sure validation works in multi GPU setup
* more fixes
* up
1 parent e43e206 commit 2f9a70a

File tree: 2 files changed (+86, -78 lines)
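Summary of the change: in both training scripts the per-epoch validation block is now wrapped in an `accelerator.is_main_process` check, so that in a multi-GPU run only rank 0 builds the validation `DiffusionPipeline` and logs images to the trackers. A minimal, self-contained sketch of that pattern (the `validate()` function below is a stand-in for the pipeline inference in the real scripts, and the epoch count is arbitrary):

from accelerate import Accelerator

accelerator = Accelerator()


def validate(epoch: int) -> None:
    # Stand-in for: build the DiffusionPipeline, generate validation images, log to trackers.
    print(f"[rank {accelerator.process_index}] running validation for epoch {epoch}")


for epoch in range(2):
    # ... training steps run on every process ...

    if accelerator.is_main_process:
        # Only rank 0 performs validation; the other ranks continue straight to the next epoch.
        validate(epoch)

# As in the example scripts, synchronize the processes before saving the LoRA layers.
accelerator.wait_for_everyone()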

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 45 additions & 40 deletions
@@ -923,44 +923,47 @@ def main(args):
             if global_step >= args.max_train_steps:
                 break

-        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-            logger.info(
-                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                f" {args.validation_prompt}."
-            )
-            # create pipeline
-            pipeline = DiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                unet=accelerator.unwrap_model(unet),
-                text_encoder=accelerator.unwrap_model(text_encoder),
-                revision=args.revision,
-                torch_dtype=weight_dtype,
-            )
-            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-            pipeline = pipeline.to(accelerator.device)
-            pipeline.set_progress_bar_config(disable=True)
-
-            # run inference
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
-            prompt = args.num_validation_images * [args.validation_prompt]
-            images = pipeline(prompt, num_inference_steps=25, generator=generator).images
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "validation": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
-
-            del pipeline
-            torch.cuda.empty_cache()
+        if accelerator.is_main_process:
+            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+                logger.info(
+                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+                    f" {args.validation_prompt}."
+                )
+                # create pipeline
+                pipeline = DiffusionPipeline.from_pretrained(
+                    args.pretrained_model_name_or_path,
+                    unet=accelerator.unwrap_model(unet),
+                    text_encoder=accelerator.unwrap_model(text_encoder),
+                    revision=args.revision,
+                    torch_dtype=weight_dtype,
+                )
+                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+                pipeline = pipeline.to(accelerator.device)
+                pipeline.set_progress_bar_config(disable=True)
+
+                # run inference
+                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                images = [
+                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+                    for _ in range(args.num_validation_images)
+                ]
+
+                for tracker in accelerator.trackers:
+                    if tracker.name == "tensorboard":
+                        np_images = np.stack([np.asarray(img) for img in images])
+                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+                    if tracker.name == "wandb":
+                        tracker.log(
+                            {
+                                "validation": [
+                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                                    for i, image in enumerate(images)
+                                ]
+                            }
+                        )
+
+                del pipeline
+                torch.cuda.empty_cache()

     # Save the lora layers
     accelerator.wait_for_everyone()
@@ -982,8 +985,10 @@ def main(args):
         # run inference
         if args.validation_prompt and args.num_validation_images > 0:
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            prompt = args.num_validation_images * [args.validation_prompt]
-            images = pipeline(prompt, num_inference_steps=25, generator=generator).images
+            images = [
+                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+                for _ in range(args.num_validation_images)
+            ]

         for tracker in accelerator.trackers:
             if tracker.name == "tensorboard":
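For completeness, a hypothetical inference sketch (not part of this commit) showing how the trained DreamBooth LoRA weights might be used afterwards. It assumes a diffusers release contemporary with this change, where the example saves attention-processor weights that `unet.load_attn_procs` can load; the base checkpoint, output directory, and prompt are placeholders:

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

base_model = "runwayml/stable-diffusion-v1-5"  # placeholder base checkpoint
lora_output_dir = "path/to/lora/output"        # placeholder for the script's --output_dir

pipeline = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.unet.load_attn_procs(lora_output_dir)  # load the trained LoRA attention processors
pipeline = pipeline.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipeline("a photo of sks dog in a bucket", num_inference_steps=25, generator=generator).images[0]
image.save("lora_sample.png")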

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 41 additions & 38 deletions
@@ -749,44 +749,47 @@ def collate_fn(examples):
             if global_step >= args.max_train_steps:
                 break

-        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-            logger.info(
-                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                f" {args.validation_prompt}."
-            )
-            # create pipeline
-            pipeline = DiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                unet=accelerator.unwrap_model(unet),
-                revision=args.revision,
-                torch_dtype=weight_dtype,
-            )
-            pipeline = pipeline.to(accelerator.device)
-            pipeline.set_progress_bar_config(disable=True)
-
-            # run inference
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
-            images = []
-            for _ in range(args.num_validation_images):
-                images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
-
-            if accelerator.is_main_process:
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
-
-            del pipeline
-            torch.cuda.empty_cache()
+        if accelerator.is_main_process:
+            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+                logger.info(
+                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+                    f" {args.validation_prompt}."
+                )
+                # create pipeline
+                pipeline = DiffusionPipeline.from_pretrained(
+                    args.pretrained_model_name_or_path,
+                    unet=accelerator.unwrap_model(unet),
+                    revision=args.revision,
+                    torch_dtype=weight_dtype,
+                )
+                pipeline = pipeline.to(accelerator.device)
+                pipeline.set_progress_bar_config(disable=True)
+
+                # run inference
+                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                images = []
+                for _ in range(args.num_validation_images):
+                    images.append(
+                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+                    )
+
+                if accelerator.is_main_process:
+                    for tracker in accelerator.trackers:
+                        if tracker.name == "tensorboard":
+                            np_images = np.stack([np.asarray(img) for img in images])
+                            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+                        if tracker.name == "wandb":
+                            tracker.log(
+                                {
+                                    "validation": [
+                                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                                        for i, image in enumerate(images)
+                                    ]
+                                }
+                            )
+
+                del pipeline
+                torch.cuda.empty_cache()

     # Save the lora layers
     accelerator.wait_for_everyone()
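In both scripts the validation images are produced one pipeline call at a time from a generator seeded with `args.seed` (the DreamBooth script previously made a single batched call with the prompt repeated `num_validation_images` times). A small helper capturing that pattern, with illustrative names and defaults that are not part of the scripts:

from typing import List

import torch
from diffusers import DiffusionPipeline
from PIL import Image


def generate_validation_images(
    pipeline: DiffusionPipeline,
    prompt: str,
    num_images: int,
    seed: int,
    num_inference_steps: int = 30,
) -> List[Image.Image]:
    # One call per image, all sharing a generator freshly seeded with `seed`,
    # so repeated validation passes start from the same noise and stay comparable.
    generator = torch.Generator(device=pipeline.device).manual_seed(seed)
    return [
        pipeline(prompt, num_inference_steps=num_inference_steps, generator=generator).images[0]
        for _ in range(num_images)
    ]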
