Skip to content

Commit 4f14b36

Browse files
Full Dreambooth IF stage II upscaling (huggingface#3561)
* update dreambooth lora to work with IF stage II
* Update dreambooth script for IF stage II upscaler
1 parent f751b88 commit 4f14b36

File tree

1 file changed

+46
-9
lines changed

1 file changed

+46
-9
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 46 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@
5252
from diffusers.optimization import get_scheduler
5353
from diffusers.utils import check_min_version, is_wandb_available
5454
from diffusers.utils.import_utils import is_xformers_available
55+
from diffusers.utils.torch_utils import randn_tensor
5556

5657

5758
if is_wandb_available():
@@ -114,16 +115,17 @@ def log_validation(
114115

115116
pipeline_args = {}
116117

117-
if text_encoder is not None:
118-
pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
119-
120118
if vae is not None:
121119
pipeline_args["vae"] = vae
122120

121+
if text_encoder is not None:
122+
text_encoder = accelerator.unwrap_model(text_encoder)
123+
123124
# create pipeline (note: unet and vae are loaded again in float32)
124125
pipeline = DiffusionPipeline.from_pretrained(
125126
args.pretrained_model_name_or_path,
126127
tokenizer=tokenizer,
128+
text_encoder=text_encoder,
127129
unet=accelerator.unwrap_model(unet),
128130
revision=args.revision,
129131
torch_dtype=weight_dtype,
@@ -156,10 +158,16 @@ def log_validation(
156158
# run inference
157159
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
158160
images = []
159-
for _ in range(args.num_validation_images):
160-
with torch.autocast("cuda"):
161-
image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
162-
images.append(image)
161+
if args.validation_images is None:
162+
for _ in range(args.num_validation_images):
163+
with torch.autocast("cuda"):
164+
image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
165+
images.append(image)
166+
else:
167+
for image in args.validation_images:
168+
image = Image.open(image)
169+
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
170+
images.append(image)
163171

164172
for tracker in accelerator.trackers:
165173
if tracker.name == "tensorboard":
@@ -525,6 +533,19 @@ def parse_args(input_args=None):
525533
parser.add_argument(
526534
"--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
527535
)
536+
parser.add_argument(
537+
"--validation_images",
538+
required=False,
539+
default=None,
540+
nargs="+",
541+
help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
542+
)
543+
parser.add_argument(
544+
"--class_labels_conditioning",
545+
required=False,
546+
default=None,
547+
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
548+
)
528549

529550
if input_args is not None:
530551
args = parser.parse_args(input_args)
@@ -1169,7 +1190,7 @@ def compute_text_embeddings(prompt):
11691190
)
11701191
else:
11711192
noise = torch.randn_like(model_input)
1172-
bsz = model_input.shape[0]
1193+
bsz, channels, height, width = model_input.shape
11731194
# Sample a random timestep for each image
11741195
timesteps = torch.randint(
11751196
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
@@ -1191,8 +1212,24 @@ def compute_text_embeddings(prompt):
11911212
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
11921213
)
11931214

1215+
if unet.config.in_channels > channels:
1216+
needed_additional_channels = unet.config.in_channels - channels
1217+
additional_latents = randn_tensor(
1218+
(bsz, needed_additional_channels, height, width),
1219+
device=noisy_model_input.device,
1220+
dtype=noisy_model_input.dtype,
1221+
)
1222+
noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
1223+
1224+
if args.class_labels_conditioning == "timesteps":
1225+
class_labels = timesteps
1226+
else:
1227+
class_labels = None
1228+
11941229
# Predict the noise residual
1195-
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
1230+
model_pred = unet(
1231+
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
1232+
).sample
11961233

11971234
if model_pred.shape[1] == 6:
11981235
model_pred, _ = torch.chunk(model_pred, 2, dim=1)

0 commit comments

Comments
 (0)