Commit 880c0fd

[advanced dreambooth lora training script][bug_fix] change token_abstraction type to str (huggingface#6040)
* improve help tags
* style fix
* changes token_abstraction type to string. support multiple concepts for pivotal using a comma separated string.
* style fixup
* changed logger to warning (not yet available)
* moved the token_abstraction parsing to be in the same block as where we create the mapping of identifier to token

---------

Co-authored-by: Linoy <[email protected]>
1 parent c36f1c3 commit 880c0fd
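
In short, --token_abstraction is now always read as a plain string, and main() splits it into the list of identifiers. A minimal standalone sketch of that split (it mirrors the expression added in the last hunk below; the helper name split_token_abstraction is hypothetical and not part of the script):

def split_token_abstraction(value):
    # strip any whitespace, then split on commas
    return "".join(value.split()).split(",")

print(split_token_abstraction("TOK"))        # ['TOK']
print(split_token_abstraction("TOK, TOK2"))  # ['TOK', 'TOK2']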

File tree

1 file changed (+10 / -14 lines)

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 10 additions & 14 deletions
@@ -300,16 +300,18 @@ def parse_args(input_args=None):
     )
     parser.add_argument(
         "--token_abstraction",
+        type=str,
         default="TOK",
         help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
-        "captions - e.g. TOK",
+        "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
+        "'TOK,TOK2,TOK3' etc.",
     )
 
     parser.add_argument(
         "--num_new_tokens_per_abstraction",
         type=int,
         default=2,
-        help="number of new tokens inserted to the tokenizers per token_abstraction value when "
+        help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
         "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
         "tokens - <si><si+1> ",
     )
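
To illustrate what the type=str change means at parse time, here is a minimal standalone argparse sketch (only the flag shown above, not the script's full parser): a comma-separated value arrives as one str, and the splitting into identifiers is deferred to main().

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--token_abstraction", type=str, default="TOK")

args = parser.parse_args(["--token_abstraction", "TOK,TOK2"])
print(type(args.token_abstraction).__name__, args.token_abstraction)  # str TOK,TOK2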
@@ -660,17 +662,6 @@ def parse_args(input_args=None):
             "inversion training check `--train_text_encoder_ti`"
         )
 
-    if args.train_text_encoder_ti:
-        if isinstance(args.token_abstraction, str):
-            args.token_abstraction = [args.token_abstraction]
-        elif isinstance(args.token_abstraction, List):
-            args.token_abstraction = args.token_abstraction
-        else:
-            raise ValueError(
-                f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
-                f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
-            )
-
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -1155,9 +1146,14 @@ def main(args):
     )
 
     if args.train_text_encoder_ti:
+        # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
+        # TOK2" -> ["TOK", "TOK2"] etc.
+        token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
+        logger.info(f"list of token identifiers: {token_abstraction_list}")
+
         token_abstraction_dict = {}
         token_idx = 0
-        for i, token in enumerate(args.token_abstraction):
+        for i, token in enumerate(token_abstraction_list):
             token_abstraction_dict[token] = [
                 f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
             ]
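
To make the new mapping step concrete, a small sketch of the first loop iteration, assuming the default num_new_tokens_per_abstraction=2 and an already-parsed list ["TOK", "TOK2"]; later iterations are omitted because the token index bookkeeping continues past this hunk.

num_new_tokens_per_abstraction = 2  # the script's default
token_abstraction_list = ["TOK", "TOK2"]  # result of the comma split above

token_abstraction_dict = {}
token_idx = 0
i, token = 0, token_abstraction_list[0]
# the first identifier is mapped to two fresh placeholder tokens
token_abstraction_dict[token] = [
    f"<s{token_idx + i + j}>" for j in range(num_new_tokens_per_abstraction)
]
print(token_abstraction_dict)  # {'TOK': ['<s0>', '<s1>']}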
