Skip to content

Commit 9d97440

Browse files
authored
[Easy] fix: save_model_card utility of the DreamBooth SDXL LoRA script (huggingface#7258)
* fix: save_model_card utility. * fix a little more to make it more lenient. * remove lower()
1 parent d9a3b69 commit 9d97440

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def save_model_card(
114114
)
115115

116116
model_description = f"""
117-
# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
117+
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
118118
119119
<Gallery />
120120
@@ -139,7 +139,7 @@ def save_model_card(
139139
[Download]({repo_id}/tree/main) them in the Files & versions tab.
140140
141141
"""
142-
if "playgroundai" in args.pretrained_model_name_or_path:
142+
if "playground" in base_model:
143143
model_description += """\n
144144
## License
145145
@@ -148,7 +148,7 @@ def save_model_card(
148148
model_card = load_or_create_model_card(
149149
repo_id_or_path=repo_id,
150150
from_training=True,
151-
license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
151+
license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
152152
base_model=base_model,
153153
prompt=instance_prompt,
154154
model_description=model_description,
@@ -162,7 +162,7 @@ def save_model_card(
162162
"lora" if not use_dora else "dora",
163163
"template:sd-lora",
164164
]
165-
if "playgroundai" in base_model:
165+
if "playground" in base_model:
166166
tags.extend(["playground", "playground-diffusers"])
167167
else:
168168
tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
@@ -206,7 +206,7 @@ def log_validation(
206206
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
207207
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
208208
inference_ctx = (
209-
contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
209+
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
210210
)
211211

212212
with inference_ctx:
@@ -1509,7 +1509,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
15091509
if accelerator.is_main_process:
15101510
tracker_name = (
15111511
"dreambooth-lora-sd-xl"
1512-
if "playgroundai" not in args.pretrained_model_name_or_path
1512+
if "playground" not in args.pretrained_model_name_or_path
15131513
else "dreambooth-lora-playground"
15141514
)
15151515
accelerator.init_trackers(tracker_name, config=vars(args))

0 commit comments

Comments
 (0)