Skip to content

Commit 326de41

Browse files
authored
Trivial fix for undefined symbol in train_dreambooth.py (huggingface#1598)
easy fix for undefined name in train_dreambooth.py import_model_class_from_model_name_or_path loads a pretrained model and refers to args.revision in a context where args is undefined. I modified the function to take revision as an argument and modified the invocation of the function to pass in the revision from args. Seems like this was caused by a cut and paste.
1 parent eb1abee commit 326de41

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/dreambooth/train_dreambooth.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
logger = get_logger(__name__)
3131

3232

33-
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
33+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
3434
text_encoder_config = PretrainedConfig.from_pretrained(
3535
pretrained_model_name_or_path,
3636
subfolder="text_encoder",
37-
revision=args.revision,
37+
revision=revision,
3838
)
3939
model_class = text_encoder_config.architectures[0]
4040

@@ -469,7 +469,7 @@ def main(args):
469469
)
470470

471471
# import correct text encoder class
472-
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
472+
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
473473

474474
# Load models and create wrapper for stable diffusion
475475
text_encoder = text_encoder_cls.from_pretrained(

0 commit comments

Comments
 (0)