Skip to content

Commit 7081a25

Browse files
sayakpaulpcuenca
andauthored
[Examples] Multiple enhancements to the ControlNet training scripts (huggingface#7096)
* log_validation unification for controlnet. * additional fixes. * remove print. * better reuse and loading * make final inference run conditional. * Update examples/controlnet/README_sdxl.md Co-authored-by: Pedro Cuenca <[email protected]> * resize the control image in the snippet. --------- Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 848f9fe commit 7081a25

File tree

3 files changed

+103
-22
lines changed

3 files changed

+103
-22
lines changed

examples/controlnet/README_sdxl.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ pipe.enable_xformers_memory_efficient_attention()
113113
# memory optimization.
114114
pipe.enable_model_cpu_offload()
115115

116-
control_image = load_image("./conditioning_image_1.png")
116+
control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
117117
prompt = "pale golden rod circle with old lace background"
118118

119119
# generate image
@@ -128,4 +128,14 @@ image.save("./output.png")
128128

129129
### Specifying a better VAE
130130

131-
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
131+
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
132+
133+
If you're using this VAE during training, you need to ensure you're using it during inference too. You do so by:
134+
135+
```diff
136+
+ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
137+
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
138+
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
139+
base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
140+
+ vae=vae,
141+
)

examples/controlnet/train_controlnet.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515

1616
import argparse
17+
import contextlib
18+
import gc
1719
import logging
1820
import math
1921
import os
@@ -74,10 +76,15 @@ def image_grid(imgs, rows, cols):
7476
return grid
7577

7678

77-
def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
79+
def log_validation(
80+
vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
81+
):
7882
logger.info("Running validation... ")
7983

80-
controlnet = accelerator.unwrap_model(controlnet)
84+
if not is_final_validation:
85+
controlnet = accelerator.unwrap_model(controlnet)
86+
else:
87+
controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
8188

8289
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
8390
args.pretrained_model_name_or_path,
@@ -118,14 +125,15 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
118125
)
119126

120127
image_logs = []
128+
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
121129

122130
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
123131
validation_image = Image.open(validation_image).convert("RGB")
124132

125133
images = []
126134

127135
for _ in range(args.num_validation_images):
128-
with torch.autocast("cuda"):
136+
with inference_ctx:
129137
image = pipeline(
130138
validation_prompt, validation_image, num_inference_steps=20, generator=generator
131139
).images[0]
@@ -136,6 +144,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
136144
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
137145
)
138146

147+
tracker_key = "test" if is_final_validation else "validation"
139148
for tracker in accelerator.trackers:
140149
if tracker.name == "tensorboard":
141150
for log in image_logs:
@@ -167,10 +176,14 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
167176
image = wandb.Image(image, caption=validation_prompt)
168177
formatted_images.append(image)
169178

170-
tracker.log({"validation": formatted_images})
179+
tracker.log({tracker_key: formatted_images})
171180
else:
172181
logger.warn(f"image logging not implemented for {tracker.name}")
173182

183+
del pipeline
184+
gc.collect()
185+
torch.cuda.empty_cache()
186+
174187
return image_logs
175188

176189

@@ -197,7 +210,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
197210
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
198211
img_str = ""
199212
if image_logs is not None:
200-
img_str = "You can find some example images below.\n"
213+
img_str = "You can find some example images below.\n\n"
201214
for i, log in enumerate(image_logs):
202215
images = log["images"]
203216
validation_prompt = log["validation_prompt"]
@@ -1131,6 +1144,22 @@ def load_model_hook(models, input_dir):
11311144
controlnet = unwrap_model(controlnet)
11321145
controlnet.save_pretrained(args.output_dir)
11331146

1147+
# Run a final round of validation.
1148+
image_logs = None
1149+
if args.validation_prompt is not None:
1150+
image_logs = log_validation(
1151+
vae=vae,
1152+
text_encoder=text_encoder,
1153+
tokenizer=tokenizer,
1154+
unet=unet,
1155+
controlnet=None,
1156+
args=args,
1157+
accelerator=accelerator,
1158+
weight_dtype=weight_dtype,
1159+
step=global_step,
1160+
is_final_validation=True,
1161+
)
1162+
11341163
if args.push_to_hub:
11351164
save_model_card(
11361165
repo_id,

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515

1616
import argparse
17+
import contextlib
1718
import functools
1819
import gc
1920
import logging
@@ -65,20 +66,38 @@
6566
logger = get_logger(__name__)
6667

6768

68-
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
69+
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
6970
logger.info("Running validation... ")
7071

71-
controlnet = accelerator.unwrap_model(controlnet)
72+
if not is_final_validation:
73+
controlnet = accelerator.unwrap_model(controlnet)
74+
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
75+
args.pretrained_model_name_or_path,
76+
vae=vae,
77+
unet=unet,
78+
controlnet=controlnet,
79+
revision=args.revision,
80+
variant=args.variant,
81+
torch_dtype=weight_dtype,
82+
)
83+
else:
84+
controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
85+
if args.pretrained_vae_model_name_or_path is not None:
86+
vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
87+
else:
88+
vae = AutoencoderKL.from_pretrained(
89+
args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
90+
)
91+
92+
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
93+
args.pretrained_model_name_or_path,
94+
vae=vae,
95+
controlnet=controlnet,
96+
revision=args.revision,
97+
variant=args.variant,
98+
torch_dtype=weight_dtype,
99+
)
72100

73-
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
74-
args.pretrained_model_name_or_path,
75-
vae=vae,
76-
unet=unet,
77-
controlnet=controlnet,
78-
revision=args.revision,
79-
variant=args.variant,
80-
torch_dtype=weight_dtype,
81-
)
82101
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
83102
pipeline = pipeline.to(accelerator.device)
84103
pipeline.set_progress_bar_config(disable=True)
@@ -106,6 +125,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
106125
)
107126

108127
image_logs = []
128+
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
109129

110130
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
111131
validation_image = Image.open(validation_image).convert("RGB")
@@ -114,7 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
114134
images = []
115135

116136
for _ in range(args.num_validation_images):
117-
with torch.autocast("cuda"):
137+
with inference_ctx:
118138
image = pipeline(
119139
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
120140
).images[0]
@@ -124,6 +144,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
124144
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
125145
)
126146

147+
tracker_key = "test" if is_final_validation else "validation"
127148
for tracker in accelerator.trackers:
128149
if tracker.name == "tensorboard":
129150
for log in image_logs:
@@ -155,7 +176,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
155176
image = wandb.Image(image, caption=validation_prompt)
156177
formatted_images.append(image)
157178

158-
tracker.log({"validation": formatted_images})
179+
tracker.log({tracker_key: formatted_images})
159180
else:
160181
logger.warn(f"image logging not implemented for {tracker.name}")
161182

@@ -189,7 +210,7 @@ def import_model_class_from_model_name_or_path(
189210
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
190211
img_str = ""
191212
if image_logs is not None:
192-
img_str = "You can find some example images below.\n"
213+
img_str = "You can find some example images below.\n\n"
193214
for i, log in enumerate(image_logs):
194215
images = log["images"]
195216
validation_prompt = log["validation_prompt"]
@@ -1228,7 +1249,13 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
12281249

12291250
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
12301251
image_logs = log_validation(
1231-
vae, unet, controlnet, args, accelerator, weight_dtype, global_step
1252+
vae=vae,
1253+
unet=unet,
1254+
controlnet=controlnet,
1255+
args=args,
1256+
accelerator=accelerator,
1257+
weight_dtype=weight_dtype,
1258+
step=global_step,
12321259
)
12331260

12341261
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1244,6 +1271,21 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
12441271
controlnet = unwrap_model(controlnet)
12451272
controlnet.save_pretrained(args.output_dir)
12461273

1274+
# Run a final round of validation.
1275+
# Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
1276+
image_logs = None
1277+
if args.validation_prompt is not None:
1278+
image_logs = log_validation(
1279+
vae=None,
1280+
unet=None,
1281+
controlnet=None,
1282+
args=args,
1283+
accelerator=accelerator,
1284+
weight_dtype=weight_dtype,
1285+
step=global_step,
1286+
is_final_validation=True,
1287+
)
1288+
12471289
if args.push_to_hub:
12481290
save_model_card(
12491291
repo_id,

0 commit comments

Comments
 (0)