
Commit d73e6ad

guard save model hooks to only execute on main process (huggingface#4929)
1 parent d0cf681

File tree

14 files changed: +147 −133 lines changed

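Every diff in this commit makes the same change: the script's save_model_hook is wrapped in an `accelerator.is_main_process` check, so that when `accelerator.save_state(...)` fires the hook, only the main process writes checkpoint files rather than every rank racing to write the same paths. Below is a minimal sketch of the pattern, not an excerpt from any one script: the single prepared model and the "model" subfolder name are illustrative, while the hook registration API (register_save_state_pre_hook) is the accelerate >= 0.16.0 feature these scripts gate on with the version check visible in the diffs.

import os

import accelerate
from accelerate import Accelerator
from packaging import version

accelerator = Accelerator()

if version.parse(accelerate.__version__) >= version.parse("0.16.0"):

    def save_model_hook(models, weights, output_dir):
        # The hook can run on every process; the guard makes rank 0 the only writer.
        if accelerator.is_main_process:
            for model in models:
                model.save_pretrained(os.path.join(output_dir, "model"))
                # pop the weight so accelerate does not serialize it a second time
                weights.pop()

    accelerator.register_save_state_pre_hook(save_model_hook)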

examples/controlnet/train_controlnet.py

Lines changed: 8 additions & 7 deletions

@@ -785,16 +785,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1
 
-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]
 
-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
 
-                i -= 1
+                    i -= 1
 
         def load_model_hook(models, input_dir):
             while len(models) > 0:

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 8 additions & 7 deletions

@@ -840,16 +840,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1
 
-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]
 
-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
 
-                i -= 1
+                    i -= 1
 
         def load_model_hook(models, input_dir):
             while len(models) > 0:

examples/dreambooth/train_dreambooth.py

Lines changed: 6 additions & 5 deletions

@@ -920,12 +920,13 @@ def main(args):
 
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            for model in models:
-                sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
+            if accelerator.is_main_process:
+                for model in models:
+                    sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
 
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
 
         def load_model_hook(models, input_dir):
             while len(models) > 0:

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 20 additions & 19 deletions

@@ -894,27 +894,28 @@ def main(args):
 
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            # there are only two options here. Either are just the unet attn processor layers
-            # or there are the unet and text encoder atten layers
-            unet_lora_layers_to_save = None
-            text_encoder_lora_layers_to_save = None
-
-            for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-                    text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+            if accelerator.is_main_process:
+                # there are only two options here. Either are just the unet attn processor layers
+                # or there are the unet and text encoder atten layers
+                unet_lora_layers_to_save = None
+                text_encoder_lora_layers_to_save = None
+
+                for model in models:
+                    if isinstance(model, type(accelerator.unwrap_model(unet))):
+                        unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                    elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                        text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
 
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
 
-            LoraLoaderMixin.save_lora_weights(
-                output_dir,
-                unet_lora_layers=unet_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_lora_layers_to_save,
-            )
+                LoraLoaderMixin.save_lora_weights(
+                    output_dir,
+                    unet_lora_layers=unet_lora_layers_to_save,
+                    text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+                )
 
         def load_model_hook(models, input_dir):
             unet_ = None

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 24 additions & 23 deletions

@@ -798,31 +798,32 @@ def main(args):
 
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            # there are only two options here. Either are just the unet attn processor layers
-            # or there are the unet and text encoder atten layers
-            unet_lora_layers_to_save = None
-            text_encoder_one_lora_layers_to_save = None
-            text_encoder_two_lora_layers_to_save = None
-
-            for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+            if accelerator.is_main_process:
+                # there are only two options here. Either are just the unet attn processor layers
+                # or there are the unet and text encoder atten layers
+                unet_lora_layers_to_save = None
+                text_encoder_one_lora_layers_to_save = None
+                text_encoder_two_lora_layers_to_save = None
+
+                for model in models:
+                    if isinstance(model, type(accelerator.unwrap_model(unet))):
+                        unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                    elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                        text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                        text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
 
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
 
-            StableDiffusionXLPipeline.save_lora_weights(
-                output_dir,
-                unet_lora_layers=unet_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
-                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
-            )
+                StableDiffusionXLPipeline.save_lora_weights(
+                    output_dir,
+                    unet_lora_layers=unet_lora_layers_to_save,
+                    text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                    text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                )
 
         def load_model_hook(models, input_dir):
             unet_ = None

examples/instruct_pix2pix/train_instruct_pix2pix.py

Lines changed: 7 additions & 6 deletions

@@ -485,14 +485,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
 
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
 
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
 
         def load_model_hook(models, input_dir):
             if args.use_ema:

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

Lines changed: 7 additions & 6 deletions

@@ -528,14 +528,15 @@ def main():
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
 
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
 
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
 
         def load_model_hook(models, input_dir):
             if args.use_ema:

examples/research_projects/controlnet/train_controlnet_webdataset.py

Lines changed: 8 additions & 7 deletions

@@ -1010,16 +1010,17 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            i = len(weights) - 1
+            if accelerator.is_main_process:
+                i = len(weights) - 1
 
-            while len(weights) > 0:
-                weights.pop()
-                model = models[i]
+                while len(weights) > 0:
+                    weights.pop()
+                    model = models[i]
 
-                sub_dir = "controlnet"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
+                    sub_dir = "controlnet"
+                    model.save_pretrained(os.path.join(output_dir, sub_dir))
 
-                i -= 1
+                    i -= 1
 
         def load_model_hook(models, input_dir):
             while len(models) > 0:

examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py

Lines changed: 7 additions & 6 deletions

@@ -552,14 +552,15 @@ def compute_snr(timesteps):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
 
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
 
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
 
         def load_model_hook(models, input_dir):
             if args.use_ema:

examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py

Lines changed: 7 additions & 6 deletions

@@ -313,14 +313,15 @@ def main(args):
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
-            if args.use_ema:
-                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
 
-            for i, model in enumerate(models):
-                model.save_pretrained(os.path.join(output_dir, "unet"))
+                for i, model in enumerate(models):
+                    model.save_pretrained(os.path.join(output_dir, "unet"))
 
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
 
         def load_model_hook(models, input_dir):
             if args.use_ema:
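The matching load_model_hooks are untouched by this commit: restoring weights on resume has to happen on every process, so only the save side gets the guard. For reference, a sketch of the call site where the guarded hook fires; maybe_checkpoint is a hypothetical helper, and global_step / args.checkpointing_steps follow the naming conventions of these training scripts rather than quoting any one of them.

import os

def maybe_checkpoint(accelerator, args, global_step):
    # Hypothetical helper mirroring the scripts' checkpointing step.
    if global_step % args.checkpointing_steps == 0:
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        # save_state invokes the registered save_model_hook; after this commit,
        # only the main process actually writes the model subfolders.
        accelerator.save_state(save_path)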
