Skip to content

Commit 29dfe22

Browse files
[advanced dreambooth lora sdxl training script] load pipeline for inference only if validation prompt is used (huggingface#6171)
* load pipeline for inference only if validation prompt is used * move things outside * load pipeline for inference only if validation prompt is used * fix readme when validation prompt is used --------- Co-authored-by: linoytsaban <[email protected]> Co-authored-by: apolinário <[email protected]>
1 parent 56806cd commit 29dfe22

File tree

1 file changed

+34
-33
lines changed

1 file changed

+34
-33
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def save_model_card(
112112
repo_folder=None,
113113
vae_path=None,
114114
):
115-
img_str = "widget:\n" if images else ""
115+
img_str = "widget:\n"
116116
for i, image in enumerate(images):
117117
image.save(os.path.join(repo_folder, f"image_{i}.png"))
118118
img_str += f"""
@@ -121,6 +121,10 @@ def save_model_card(
121121
url:
122122
"image_{i}.png"
123123
"""
124+
if not images:
125+
img_str += f"""
126+
- text: '{instance_prompt}'
127+
"""
124128

125129
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
126130
diffusers_imports_pivotal = ""
@@ -157,8 +161,6 @@ def save_model_card(
157161
base_model: {base_model}
158162
instance_prompt: {instance_prompt}
159163
license: openrail++
160-
widget:
161-
- text: '{validation_prompt if validation_prompt else instance_prompt}'
162164
---
163165
"""
164166

@@ -2010,43 +2012,42 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
20102012
text_encoder_lora_layers=text_encoder_lora_layers,
20112013
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
20122014
)
2015+
images = []
2016+
if args.validation_prompt and args.num_validation_images > 0:
2017+
# Final inference
2018+
# Load previous pipeline
2019+
vae = AutoencoderKL.from_pretrained(
2020+
vae_path,
2021+
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
2022+
revision=args.revision,
2023+
variant=args.variant,
2024+
torch_dtype=weight_dtype,
2025+
)
2026+
pipeline = StableDiffusionXLPipeline.from_pretrained(
2027+
args.pretrained_model_name_or_path,
2028+
vae=vae,
2029+
revision=args.revision,
2030+
variant=args.variant,
2031+
torch_dtype=weight_dtype,
2032+
)
20132033

2014-
# Final inference
2015-
# Load previous pipeline
2016-
vae = AutoencoderKL.from_pretrained(
2017-
vae_path,
2018-
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
2019-
revision=args.revision,
2020-
variant=args.variant,
2021-
torch_dtype=weight_dtype,
2022-
)
2023-
pipeline = StableDiffusionXLPipeline.from_pretrained(
2024-
args.pretrained_model_name_or_path,
2025-
vae=vae,
2026-
revision=args.revision,
2027-
variant=args.variant,
2028-
torch_dtype=weight_dtype,
2029-
)
2030-
2031-
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
2032-
scheduler_args = {}
2034+
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
2035+
scheduler_args = {}
20332036

2034-
if "variance_type" in pipeline.scheduler.config:
2035-
variance_type = pipeline.scheduler.config.variance_type
2037+
if "variance_type" in pipeline.scheduler.config:
2038+
variance_type = pipeline.scheduler.config.variance_type
20362039

2037-
if variance_type in ["learned", "learned_range"]:
2038-
variance_type = "fixed_small"
2040+
if variance_type in ["learned", "learned_range"]:
2041+
variance_type = "fixed_small"
20392042

2040-
scheduler_args["variance_type"] = variance_type
2043+
scheduler_args["variance_type"] = variance_type
20412044

2042-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
2045+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
20432046

2044-
# load attention processors
2045-
pipeline.load_lora_weights(args.output_dir)
2047+
# load attention processors
2048+
pipeline.load_lora_weights(args.output_dir)
20462049

2047-
# run inference
2048-
images = []
2049-
if args.validation_prompt and args.num_validation_images > 0:
2050+
# run inference
20502051
pipeline = pipeline.to(accelerator.device)
20512052
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
20522053
images = [

0 commit comments

Comments
 (0)