
Commit d29d97b

[examples/advanced_diffusion_training] bug fixes and improvements for LoRA Dreambooth SDXL advanced training script (huggingface#5935)
* imports and readme bug fixes
* bug fix - ensures text_encoder params are dtype==float32 (when using pivotal tuning) even if the rest of the model is loaded in fp16
* added pivotal tuning to readme
* mapping token identifier to new inserted token in validation prompt (if used)
* correct default value of --train_text_encoder_frac
* change default value of --adam_weight_decay_text_encoder
* validation prompt generations when using pivotal tuning bug fix
* style fix
* textual inversion embeddings name change
* style fix
* bug fix - stopping text encoder optimization halfway
* readme - will include token abstraction and new inserted tokens when using pivotal tuning - added type to --num_new_tokens_per_abstraction
* style fix

---------

Co-authored-by: Linoy Tsaban <[email protected]>
1 parent 7d4a257 commit d29d97b

File tree

1 file changed: +92 −56 lines

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 92 additions & 56 deletions
```diff
@@ -54,7 +54,7 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr, unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
```
```diff
@@ -67,11 +67,46 @@
 logger = get_logger(__name__)
 
 
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
     base_model=str,
     train_text_encoder=False,
+    train_text_encoder_ti=False,
+    token_abstraction_dict=None,
     instance_prompt=str,
     validation_prompt=str,
     repo_folder=None,
```
```diff
@@ -83,10 +118,23 @@ def save_model_card(
         img_str += f"""
     - text: '{validation_prompt if validation_prompt else ' ' }'
       output:
-        url: >-
+        url:
             "image_{i}.png"
     """
 
+    trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+    if train_text_encoder_ti:
+        trigger_str = (
+            "To trigger image generation of trained concept(or concepts) replace each concept identifier "
+            "in you prompt with the new inserted tokens:\n"
+        )
+        if token_abstraction_dict:
+            for key, value in token_abstraction_dict.items():
+                tokens = "".join(value)
+                trigger_str += f"""
+to trigger concept {key}-> use {tokens} in your prompt \n
+"""
+
     yaml = f"""
 ---
 tags:
```
```diff
@@ -96,9 +144,7 @@ def save_model_card(
 - diffusers
 - lora
 - template:sd-lora
-widget:
 {img_str}
----
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
```
```diff
@@ -112,14 +158,19 @@ def save_model_card(
 
 ## Model description
 
-These are {repo_id} LoRA adaption weights for {base_model}.
+### These are {repo_id} LoRA adaption weights for {base_model}.
+
 The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
 LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
 Special VAE used for training: {vae_path}.
 
 ## Trigger words
 
-You should use {instance_prompt} to trigger the image generation.
+{trigger_str}
 
 ## Download model
```
```diff
@@ -244,6 +295,7 @@ def parse_args(input_args=None):
 
     parser.add_argument(
         "--num_new_tokens_per_abstraction",
+        type=int,
         default=2,
         help="number of new tokens inserted to the tokenizers per token_abstraction value when "
         "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
```
```diff
@@ -455,7 +507,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--train_text_encoder_frac",
         type=float,
-        default=0.5,
+        default=1.0,
         help=("The percentage of epochs to perform text encoder tuning"),
     )
 
```
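For context on the default change above, a paraphrased sketch (not the verbatim script code) of how the flag is consumed further down: the fraction is turned into the epoch at which text-encoder optimization stops, so `1.0` keeps the text encoder training for the whole run, while the old `0.5` always pivoted halfway.

```python
# paraphrased sketch: --train_text_encoder_frac sets the epoch at which
# text-encoder tuning stops (the "pivot") relative to the total epoch count
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
```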
```diff
@@ -488,7 +540,7 @@ def parse_args(input_args=None):
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
     parser.add_argument(
-        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+        "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
     )
 
     parser.add_argument(
```
```diff
@@ -679,12 +731,19 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
     def save_embeddings(self, file_path: str):
         assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
         tensors = {}
+        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
+        idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
                 self.tokenizers[0]
             ), "Tokenizers should be the same."
             new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
-            tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+            # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
+            # text_encoder 1) to keep compatible with the ecosystem.
+            # Note: When loading with diffusers, any name can work - simply specify in inference
+            tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
+            # tensors[f"text_encoders_{idx}"] = new_token_embeddings
 
         save_file(tensors, file_path)
 
```
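The note in the hunk above ("any name can work - simply specify in inference") is easiest to see with a loading sketch. The following is a hedged example, assuming the run wrote `pytorch_lora_weights.safetensors` and `embeddings.safetensors` to `output_dir` and inserted the tokens `<s0>` and `<s1>`; adjust the paths and token names to what your run actually produced.

```python
import torch
from safetensors.torch import load_file
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# hypothetical output locations produced by the training run
pipe.load_lora_weights("output_dir", weight_name="pytorch_lora_weights.safetensors")
state_dict = load_file("output_dir/embeddings.safetensors")

# "clip_l" feeds the first text encoder, "clip_g" the second (see the hunk above);
# the token names must match the tokens inserted during training, e.g. <s0><s1>
pipe.load_textual_inversion(
    state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer
)
pipe.load_textual_inversion(
    state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2
)

image = pipe("a photo of <s0><s1>").images[0]
```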
```diff
@@ -696,19 +755,6 @@ def dtype(self):
     def device(self):
         return self.text_encoders[0].device
 
-    # def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
-    #     # Assuming new tokens are of the format <s_i>
-    #     self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
-    #     special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
-    #     tokenizer.add_special_tokens(special_tokens_dict)
-    #     text_encoder.resize_token_embeddings(len(tokenizer))
-    #
-    #     self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
-    #     assert self.train_ids is not None, "New tokens could not be converted to IDs."
-    #     text_encoder.text_model.embeddings.token_embedding.weight.data[
-    #         self.train_ids
-    #     ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
-
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
```
```diff
@@ -730,15 +776,6 @@ def retract_embeddings(self):
             new_embeddings = new_embeddings * (off_ratio**0.1)
             text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
 
-    # def load_embeddings(self, file_path: str):
-    #     with safe_open(file_path, framework="pt", device=self.device.type) as f:
-    #         for idx in range(len(self.text_encoders)):
-    #             text_encoder = self.text_encoders[idx]
-    #             tokenizer = self.tokenizers[idx]
-    #
-    #             loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
-    #             self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
-
 
 class DreamBoothDataset(Dataset):
     """
```
```diff
@@ -1216,13 +1253,17 @@ def main(args):
         text_lora_parameters_one = []
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
+                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+                param = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
                 param.requires_grad = False
         text_lora_parameters_two = []
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
+                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+                param = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
```
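The motivation for the added cast: only the token-embedding rows are optimized, so only they need float32 precision while the rest of the encoder can stay in fp16. A minimal standalone sketch of that idea (the checkpoint name is illustrative only, and the cast is done in place here rather than by rebinding the loop variable):

```python
import torch
from transformers import CLIPTextModel

# load the encoder in half precision to save memory ...
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.float16
)

# ... but keep the parameters that are actually optimized in float32,
# otherwise fp16 gradient updates quickly under- or overflow
trainable_params = []
for name, param in text_encoder.named_parameters():
    if "token_embedding" in name:
        param.data = param.data.to(torch.float32)
        param.requires_grad_(True)
        trainable_params.append(param)
    else:
        param.requires_grad_(False)
```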
```diff
@@ -1309,12 +1350,16 @@ def load_model_hook(models, input_dir):
         # different learning rate for text encoder and unet
         text_lora_parameters_one_with_lr = {
             "params": text_lora_parameters_one,
-            "weight_decay": args.adam_weight_decay_text_encoder,
+            "weight_decay": args.adam_weight_decay_text_encoder
+            if args.adam_weight_decay_text_encoder
+            else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         text_lora_parameters_two_with_lr = {
             "params": text_lora_parameters_two,
-            "weight_decay": args.adam_weight_decay_text_encoder,
+            "weight_decay": args.adam_weight_decay_text_encoder
+            if args.adam_weight_decay_text_encoder
+            else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         params_to_optimize = [
```
```diff
@@ -1494,6 +1539,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
+    if args.train_text_encoder_ti and args.validation_prompt:
+        # replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
+        for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
+            args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+            print("validation prompt:", args.validation_prompt)
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
```
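To make the replacement concrete, each token abstraction (e.g. `TOK`) is swapped for the concatenation of its inserted tokens before the validation prompt is used; a tiny illustration with a hypothetical mapping:

```python
# hypothetical mapping produced with --token_abstraction "TOK"
# and --num_new_tokens_per_abstraction 2
token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}

validation_prompt = "a photo of TOK wearing a red hat"
for token_abs, token_replacement in token_abstraction_dict.items():
    validation_prompt = validation_prompt.replace(token_abs, "".join(token_replacement))

print(validation_prompt)  # -> "a photo of <s0><s1> wearing a red hat"
```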
```diff
@@ -1593,27 +1644,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                params_to_optimize = params_to_optimize[:1]
-                # reinitializing the optimizer to optimize only on unet params
-                if args.optimizer.lower() == "prodigy":
-                    optimizer = optimizer_class(
-                        params_to_optimize,
-                        lr=args.learning_rate,
-                        betas=(args.adam_beta1, args.adam_beta2),
-                        beta3=args.prodigy_beta3,
-                        weight_decay=args.adam_weight_decay,
-                        eps=args.adam_epsilon,
-                        decouple=args.prodigy_decouple,
-                        use_bias_correction=args.prodigy_use_bias_correction,
-                        safeguard_warmup=args.prodigy_safeguard_warmup,
-                    )
-                else:  # AdamW or 8-bit-AdamW
-                    optimizer = optimizer_class(
-                        params_to_optimize,
-                        betas=(args.adam_beta1, args.adam_beta2),
-                        weight_decay=args.adam_weight_decay,
-                        eps=args.adam_epsilon,
-                    )
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             else:
                 # still optimizng the text encoder
                 text_encoder_one.train()
```
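The fix above zeroes the learning rate of the two text-encoder parameter groups instead of rebuilding the optimizer, which would have discarded the accumulated Adam/Prodigy state for the UNet group. A toy sketch (standalone modules, not the actual script objects) of why a zero learning rate freezes a group under AdamW-style updates:

```python
import torch

unet_like = torch.nn.Linear(4, 4)       # stands in for the unet param group
text_enc_like = torch.nn.Linear(4, 4)   # stands in for a text-encoder param group
optimizer = torch.optim.AdamW(
    [
        {"params": unet_like.parameters(), "lr": 1e-4},
        {"params": text_enc_like.parameters(), "lr": 5e-5},
    ]
)

before = text_enc_like.weight.detach().clone()
optimizer.param_groups[1]["lr"] = 0.0   # the "pivot": stop tuning this group

loss = unet_like(torch.randn(2, 4)).sum() + text_enc_like(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()

# both the decoupled weight decay and the Adam update are scaled by lr,
# so the zeroed group does not move while the unet group keeps training
assert torch.allclose(before, text_enc_like.weight)
```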
```diff
@@ -1628,7 +1662,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
-                print(prompts)
+                # print(prompts)
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:
```
```diff
@@ -1801,7 +1835,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     f" {args.validation_prompt}."
                 )
                 # create pipeline
-                if not args.train_text_encoder:
+                if freeze_text_encoder:
                     text_encoder_one = text_encoder_cls_one.from_pretrained(
                         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
                     )
```
```diff
@@ -1948,6 +1982,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             images=images,
             base_model=args.pretrained_model_name_or_path,
             train_text_encoder=args.train_text_encoder,
+            train_text_encoder_ti=args.train_text_encoder_ti,
+            token_abstraction_dict=train_dataset.token_abstraction_dict,
             instance_prompt=args.instance_prompt,
             validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,
```
