Fixed the bug related to saving DeepSpeed models. #6628


Merged

5 commits merged into huggingface:main on Jan 19, 2024

Conversation

HelloWorldBeginner
Contributor

@HelloWorldBeginner HelloWorldBeginner commented Jan 18, 2024

What does this PR do?

When training a model with DeepSpeed through the accelerate library, I hit an error while saving a checkpoint. The model passed to save_model_hook is a DeepSpeedEngine, which caused an "unexpected save model" error. To resolve this, the model needs to be unwrapped first so that isinstance checks against the underlying model type succeed. With this change, the model is saved correctly.
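The idea can be sketched as follows. This is a minimal illustration of the unwrapping pattern, not the actual diffusers patch; `FakeEngine` and `MyModel` are hypothetical stand-ins for DeepSpeedEngine and the real model class, and the only assumption is that the wrapper exposes the underlying model via a `.module` attribute, as DeepSpeedEngine does.

```python
# Minimal sketch of the unwrapping idea (illustrative only; FakeEngine and
# MyModel are hypothetical stand-ins, not real library classes).

class MyModel:
    """Stands in for the real model (e.g. the LoRA-wrapped UNet)."""
    pass

class FakeEngine:
    """Stands in for a wrapper such as DeepSpeedEngine, which exposes
    the wrapped model via its .module attribute."""
    def __init__(self, module):
        self.module = module

def unwrap_model(model):
    # Peel off wrapper layers until the underlying model is reached.
    while hasattr(model, "module"):
        model = model.module
    return model

wrapped = FakeEngine(MyModel())
# Without unwrapping, an isinstance check against the model class fails:
assert not isinstance(wrapped, MyModel)
# After unwrapping, it succeeds, so a type-dispatching save hook can work:
assert isinstance(unwrap_model(wrapped), MyModel)
```

The same pattern handles nested wrappers (e.g. a distributed wrapper around an engine), since the loop keeps descending through `.module` until none remains.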

Fixes # (issue)

After this fix, the checkpoint can be saved successfully; the training log below shows a checkpoint being written:

Steps:   4%|███▊                                                                                            | 80/2000 [03:38<53:16,  1.66s/it, lr=0.0001, step_loss=0.0303]
01/18/2024 20:54:31 - INFO - accelerate.accelerator - Saving current state to sd-pokemon-model-lora-sdxl/checkpoint-80
01/18/2024 20:54:31 - INFO - accelerate.accelerator - Saving DeepSpeed Model and Optimizer
[2024-01-18 20:54:31,401] [INFO] [logging.py:96:log_dist] [Rank 0] [Torch] Checkpoint pytorch_model is about to be saved!
[2024-01-18 20:54:32,867] [INFO] [logging.py:96:log_dist] [Rank 0] Saving model checkpoint: sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/mp_rank_00_model_states.pt
[2024-01-18 20:54:32,868] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/mp_rank_00_model_states.pt...
[2024-01-18 20:54:46,415] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/mp_rank_00_model_states.pt.
[2024-01-18 20:54:46,472] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt...
[2024-01-18 20:54:46,473] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_1_mp_rank_00_optim_states.pt...
[2024-01-18 20:54:46,473] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_3_mp_rank_00_optim_states.pt...
[2024-01-18 20:54:46,473] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_2_mp_rank_00_optim_states.pt...
[2024-01-18 20:54:46,474] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_5_mp_rank_00_optim_states.pt...
[2024-01-18 20:54:46,473] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_4_mp_rank_00_optim_states.pt...
[2024-01-18 20:54:46,476] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_6_mp_rank_00_optim_states.pt...
[2024-01-18 20:54:46,477] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_7_mp_rank_00_optim_states.pt...
[2024-01-18 20:54:46,499] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt.
[2024-01-18 20:54:46,499] [INFO] [engine.py:3431:_save_zero_checkpoint] zero checkpoint saved sd-pokemon-model-lora-sdxl/checkpoint-80/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt
[2024-01-18 20:54:46,499] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint pytorch_model is ready now!

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @patrickvonplaten


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member

sayakpaul commented Jan 18, 2024

Thanks! Can you share an example training command to check with DeepSpeed?

I checked if the changes worked without DeepSpeed and they do: https://colab.research.google.com/gist/sayakpaul/6d60b261a42e0e9fb07c0e9505e7b82f/scratchpad.ipynb

@HelloWorldBeginner
Contributor Author

HelloWorldBeginner commented Jan 19, 2024

> Thanks! Can you share an example training command to check with DeepSpeed?
>
> I checked if the changes worked without DeepSpeed and they do: https://colab.research.google.com/gist/sayakpaul/6d60b261a42e0e9fb07c0e9505e7b82f/scratchpad.ipynb

Thank you for your reply. Here is my training script; the dataset is from the Hugging Face Hub, and I trained on a single A100.

Training shell script:

export MODEL_NAME="/home/mhh/sd_models/stable-diffusion-xl-base-1.0"
export VAE_NAME="/home/mhh/sd_models/sdxl-vae-fp16-fix"
export DATASET_NAME="/home/mhh/sd_datasets/pokemon-blip-captions"
export CUDA_VISIBLE_DEVICES="6"

DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
TEST_NAME="lora_fsdp_fp16"

accelerate launch  --config_file "./lora_dp_accelerate.yaml"  --main_process_port 12504 train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_NAME \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=1024  \
  --train_batch_size=1 \
  --num_train_epochs=2 \
  --checkpointing_steps=2 \
  --learning_rate=1e-04 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --max_train_steps=2000 \
  --validation_epochs=2000 \
  --seed=1234 \
  --output_dir="sd-pokemon-model-lora-sdxl" \
  --validation_prompt="cute dragon creature" | tee dp_logs/${TEST_NAME}_${DATETIME}.log

Here's my accelerate config file, lora_dp_accelerate.yaml, for training with DeepSpeed.

compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

@sayakpaul
Member

Thanks! And did you observe any speedups?

Additionally, if you could also modify the README_sdxl.md file to include a note about DeepSpeed, that would be nice. We can then merge ;-)

@HelloWorldBeginner
Contributor Author

> Thanks! And did you observe any speedups?
>
> Additionally, if you could also modify the README_sdxl.md file to include a note about DeepSpeed, that would be nice. We can then merge ;-)

With DeepSpeed, GPU memory usage decreases significantly, which makes it possible to train the model on GPUs with less memory.

I can add this information about DeepSpeed to the README later.

@sayakpaul
Member

That's very good to know. Let's add this info to the README and we can then merge :)

@HelloWorldBeginner
Contributor Author

> That's very good to know. Let's add this info to the README and we can then merge :)

I have updated the README with instructions on how to train the SDXL model using DeepSpeed, please check :)

@sayakpaul
Member

Looking fantastic. Will merge once the CI is green.

@HelloWorldBeginner
Contributor Author

> Looking fantastic. Will merge once the CI is green.

Thank you for your review. I hope this helps diffusers.

@sayakpaul sayakpaul merged commit f95615b into huggingface:main Jan 19, 2024
@sayakpaul
Member

Of course it will!

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* Fixed the bug related to saving DeepSpeed models.

* Add information about training SD models using DeepSpeed to the README.

* Apply suggestions from code review

---------

Co-authored-by: mhh001 <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 participants