<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Using multiple models with DeepSpeed

<Tip warning={true}>

This guide assumes that you have read and understood the [DeepSpeed usage guide](./deepspeed.md).

</Tip>

Running multiple models with Accelerate and DeepSpeed is useful for:

* Knowledge distillation
* Post-training techniques like RLHF (see the [TRL](https://github.com/huggingface/trl) library for more examples)
* Training multiple models at once

Currently, Accelerate has a **very experimental API** to help you use multiple models.

This tutorial will focus on two common use cases:

1. Knowledge distillation, where a smaller student model is trained to mimic a larger, better-performing teacher. If the student model fits on a single GPU, we can use ZeRO-2 for training and ZeRO-3 to shard the teacher for inference. This is significantly faster than using ZeRO-3 for both models.
2. Training multiple *disjoint* models at once.

## Knowledge distillation

Knowledge distillation is a good example of using multiple models while only training one of them.

Normally, you would use a single [`utils.DeepSpeedPlugin`] for both models. However, in this case there are two separate configurations. Accelerate allows you to create and use multiple plugins **if and only if** they are in a `dict`, so that you can reference and enable the proper plugin when needed.

```python
from accelerate.utils import DeepSpeedPlugin

zero2_plugin = DeepSpeedPlugin(hf_ds_config="zero2_config.json")
zero3_plugin = DeepSpeedPlugin(hf_ds_config="zero3_config.json")

deepspeed_plugins = {"student": zero2_plugin, "teacher": zero3_plugin}
```

The `zero2_config.json` should be configured for full training (so specify `scheduler` and `optimizer` if you are not utilizing your own), while `zero3_config.json` should only be configured for the inference model, as shown in the example below.

```json
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": "auto",
        "stage3_max_reuse_distance": "auto"
    },
    "train_micro_batch_size_per_gpu": 1
}
```

An example `zero2_config.json` configuration is shown below.

```json
{
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        }
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}
```

<Tip>

DeepSpeed will raise an error if `train_micro_batch_size_per_gpu` isn't specified, even if this particular model isn't being trained.

</Tip>

From here, create a single [`Accelerator`] and pass in both configurations.

```python
from accelerate import Accelerator

accelerator = Accelerator(deepspeed_plugins=deepspeed_plugins)
```

Now let's see how to use them.

### Student model

By default, Accelerate sets the first item in the `dict` as the default or enabled plugin (the `"student"` plugin). Verify this by using the [`utils.deepspeed.get_active_deepspeed_plugin`] function to see which plugin is enabled.

```python
from accelerate.utils.deepspeed import get_active_deepspeed_plugin

active_plugin = get_active_deepspeed_plugin(accelerator.state)
assert active_plugin is deepspeed_plugins["student"]
```

[`AcceleratorState`] also keeps the active DeepSpeed plugin saved in `state.deepspeed_plugin`.
```python
assert active_plugin is accelerator.deepspeed_plugin
```

Since `student` is the currently active plugin, let's go ahead and prepare the model, optimizer, and scheduler.

```python
student_model, optimizer, scheduler, train_dataloader = ...
student_model, optimizer, scheduler, train_dataloader = accelerator.prepare(student_model, optimizer, scheduler, train_dataloader)
```

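For reference, here is one purely illustrative way those objects could be created before the `prepare` call above. The `distilgpt2` checkpoint, the toy two-sentence dataloader, and the hyperparameters are placeholders of our own choosing, not part of Accelerate's API; swap in your real student model and data.

```python
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup

# Hypothetical student setup; replace the checkpoint, data, and hyperparameters with your own
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
student_model = AutoModelForCausalLM.from_pretrained("distilgpt2")

def collate(texts):
    # Tokenize raw strings and reuse the inputs as labels so the model returns a causal LM loss
    batch = tokenizer(texts, return_tensors="pt", padding=True)
    batch["labels"] = batch["input_ids"].clone()
    return batch

train_dataloader = DataLoader(["hello world", "knowledge distillation"], batch_size=2, collate_fn=collate)

optimizer = AdamW(student_model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=100)
```

Any model/optimizer/scheduler/dataloader combination works here; the only requirement is that they go through `accelerator.prepare` while the `student` plugin is active.
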
Now it's time to deal with the teacher model.

### Teacher model

First, you need to specify in [`Accelerator`] that the `zero3_config.json` configuration should be used.

```python
accelerator.state.select_deepspeed_plugin("teacher")
```

This disables the `"student"` plugin and enables the `"teacher"` plugin instead. The DeepSpeed stateful config inside of Transformers is updated as well, which changes which plugin configuration gets used when `deepspeed.initialize()` is called. This allows you to use the automatic `deepspeed.zero.Init` context manager integration Transformers provides.

```python
from transformers import AutoModel

teacher_model = AutoModel.from_pretrained(...)
teacher_model = accelerator.prepare(teacher_model)
```

If you are not using a Transformers model, you should manually initialize it under `deepspeed.zero.Init` instead.
```python
import deepspeed

# Pass the ZeRO-3 config so parameters are partitioned at construction time
with deepspeed.zero.Init(config_dict_or_path=accelerator.deepspeed_plugin.deepspeed_config):
    model = MyModel(...)
```

### Training

From here, your training loop can be whatever you like, as long as `teacher_model` is never trained.

```python
import torch

teacher_model.eval()
student_model.train()
for batch in train_dataloader:
    with torch.no_grad():
        output_teacher = teacher_model(**batch)
    output_student = student_model(**batch)
    # Combine the losses or modify them in some way
    loss = output_teacher.loss + output_student.loss
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```
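
The summed loss above is only a placeholder for whatever objective you want. As a point of reference, a common distillation objective compares the softened teacher and student logits with a KL-divergence term; a minimal sketch (assuming both model outputs expose `logits`, with an illustrative temperature) could look like this:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    # Soften both distributions, then measure how far the student is from the teacher
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2

# Hypothetical usage inside the loop above:
# loss = output_student.loss + distillation_loss(output_student.logits, output_teacher.logits)
```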

## Train multiple disjoint models

Training multiple models is a more complicated scenario.
In its current state, we assume each model is **completely disjoint** from the other during training.

This scenario still requires two [`utils.DeepSpeedPlugin`]s to be made. However, you also need a second [`Accelerator`], since different `deepspeed` engines are being called at different times. A single [`Accelerator`] can only carry one instance at a time.

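For this scenario, the plugin `dict` can be built the same way as in the distillation example, keyed by the names used in the rest of this section. The config file names below are placeholders for your own per-model DeepSpeed configs.

```python
from accelerate.utils import DeepSpeedPlugin

# Hypothetical config files; each model gets its own DeepSpeed configuration
deepspeed_plugins = {
    "first_model": DeepSpeedPlugin(hf_ds_config="first_model_config.json"),
    "second_model": DeepSpeedPlugin(hf_ds_config="second_model_config.json"),
}
```
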
Since [`state.AcceleratorState`] is a stateful object, it is already aware of both of the available [`utils.DeepSpeedPlugin`]s, so you can instantiate the second [`Accelerator`] with no extra arguments.

```python
first_accelerator = Accelerator(deepspeed_plugins=deepspeed_plugins)
second_accelerator = Accelerator()
```

You can call `select_deepspeed_plugin()` on either accelerator's state to enable a particular plugin (which disables the others), and then call [`prepare`].

```python
# The plugin can be selected through either accelerator's state, or by calling `AcceleratorState().select_deepspeed_plugin(...)` directly
first_accelerator.state.select_deepspeed_plugin("first_model")
first_model = AutoModel.from_pretrained(...)
# For this example, `get_training_items` is a nonexistent function that gets the setup we need for training
first_optimizer, first_scheduler, train_dl, eval_dl = get_training_items(first_model)
first_model, first_optimizer, first_scheduler, train_dl, eval_dl = first_accelerator.prepare(
    first_model, first_optimizer, first_scheduler, train_dl, eval_dl
)

second_accelerator.state.select_deepspeed_plugin("second_model")
second_model = AutoModel.from_pretrained(...)
# For this example, `get_training_items` is a nonexistent function that gets the setup we need for training
second_optimizer, second_scheduler, _, _ = get_training_items(second_model)
second_model, second_optimizer, second_scheduler = second_accelerator.prepare(
    second_model, second_optimizer, second_scheduler
)
```

And now you can train:

```python
for batch in train_dl:
    outputs1 = first_model(**batch)
    first_accelerator.backward(outputs1.loss)
    first_optimizer.step()
    first_scheduler.step()
    first_optimizer.zero_grad()

    outputs2 = second_model(**batch)
    second_accelerator.backward(outputs2.loss)
    second_optimizer.step()
    second_scheduler.step()
    second_optimizer.zero_grad()
```

## Resources

To see more examples, please check out the [related tests](https://github.com/huggingface/accelerate/blob/main/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py) currently in Accelerate.