Skip to content

Commit f751b88

Browse files
update dreambooth lora to work with IF stage II (huggingface#3560)
1 parent abb89da commit f751b88

File tree

4 files changed: 59 additions (+59) and 8 deletions (-8).

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 44 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@
6060
from diffusers.optimization import get_scheduler
6161
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
6262
from diffusers.utils.import_utils import is_xformers_available
63+
from diffusers.utils.torch_utils import randn_tensor
6364

6465

6566
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
@@ -425,6 +426,19 @@ def parse_args(input_args=None):
425426
required=False,
426427
help="Whether to use attention mask for the text encoder",
427428
)
429+
parser.add_argument(
430+
"--validation_images",
431+
required=False,
432+
default=None,
433+
nargs="+",
434+
help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
435+
)
436+
parser.add_argument(
437+
"--class_labels_conditioning",
438+
required=False,
439+
default=None,
440+
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
441+
)
428442

429443
if input_args is not None:
430444
args = parser.parse_args(input_args)
@@ -1121,7 +1135,7 @@ def compute_text_embeddings(prompt):
11211135

11221136
# Sample noise that we'll add to the latents
11231137
noise = torch.randn_like(model_input)
1124-
bsz = model_input.shape[0]
1138+
bsz, channels, height, width = model_input.shape
11251139
# Sample a random timestep for each image
11261140
timesteps = torch.randint(
11271141
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
@@ -1143,8 +1157,24 @@ def compute_text_embeddings(prompt):
11431157
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
11441158
)
11451159

1160+
if unet.config.in_channels > channels:
1161+
needed_additional_channels = unet.config.in_channels - channels
1162+
additional_latents = randn_tensor(
1163+
(bsz, needed_additional_channels, height, width),
1164+
device=noisy_model_input.device,
1165+
dtype=noisy_model_input.dtype,
1166+
)
1167+
noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
1168+
1169+
if args.class_labels_conditioning == "timesteps":
1170+
class_labels = timesteps
1171+
else:
1172+
class_labels = None
1173+
11461174
# Predict the noise residual
1147-
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
1175+
model_pred = unet(
1176+
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
1177+
).sample
11481178

11491179
# if model predicts variance, throw away the prediction. we will only train on the
11501180
# simplified training objective. This means that all schedulers using the fine tuned
@@ -1248,9 +1278,18 @@ def compute_text_embeddings(prompt):
12481278
}
12491279
else:
12501280
pipeline_args = {"prompt": args.validation_prompt}
1251-
images = [
1252-
pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
1253-
]
1281+
1282+
if args.validation_images is None:
1283+
images = [
1284+
pipeline(**pipeline_args, generator=generator).images[0]
1285+
for _ in range(args.num_validation_images)
1286+
]
1287+
else:
1288+
images = []
1289+
for image in args.validation_images:
1290+
image = Image.open(image)
1291+
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
1292+
images.append(image)
12541293

12551294
for tracker in accelerator.trackers:
12561295
if tracker.name == "tensorboard":

src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import torch.nn.functional as F
1111
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
1212

13+
from ...loaders import LoraLoaderMixin
1314
from ...models import UNet2DConditionModel
1415
from ...schedulers import DDPMScheduler
1516
from ...utils import (
@@ -112,7 +113,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
112113
"""
113114

114115

115-
class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
116+
class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
116117
tokenizer: T5Tokenizer
117118
text_encoder: T5EncoderModel
118119

@@ -1047,6 +1048,9 @@ def __call__(
10471048
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
10481049
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
10491050

1051+
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
1052+
noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
1053+
10501054
# compute the previous noisy sample x_t -> x_t-1
10511055
intermediate_images = self.scheduler.step(
10521056
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False

src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import torch.nn.functional as F
1111
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
1212

13+
from ...loaders import LoraLoaderMixin
1314
from ...models import UNet2DConditionModel
1415
from ...schedulers import DDPMScheduler
1516
from ...utils import (
@@ -114,7 +115,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
114115
"""
115116

116117

117-
class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
118+
class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
118119
tokenizer: T5Tokenizer
119120
text_encoder: T5EncoderModel
120121

@@ -1154,6 +1155,9 @@ def __call__(
11541155
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
11551156
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
11561157

1158+
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
1159+
noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
1160+
11571161
# compute the previous noisy sample x_t -> x_t-1
11581162
prev_intermediate_images = intermediate_images
11591163

src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import torch.nn.functional as F
1111
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
1212

13+
from ...loaders import LoraLoaderMixin
1314
from ...models import UNet2DConditionModel
1415
from ...schedulers import DDPMScheduler
1516
from ...utils import (
@@ -70,7 +71,7 @@
7071
"""
7172

7273

73-
class IFSuperResolutionPipeline(DiffusionPipeline):
74+
class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
7475
tokenizer: T5Tokenizer
7576
text_encoder: T5EncoderModel
7677

@@ -903,6 +904,9 @@ def __call__(
903904
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
904905
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
905906

907+
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
908+
noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
909+
906910
# compute the previous noisy sample x_t -> x_t-1
907911
intermediate_images = self.scheduler.step(
908912
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False

0 commit comments

Comments
 (0)