# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import importlib.metadata
from functools import partial

import torch.nn as nn
from deepspeed.pipe import LayerSpec, PipelineModule, TiedLayerSpec
from deepspeed.runtime import utils as ds_utils
from deepspeed.runtime.activation_checkpointing import checkpointing
from packaging.version import Version

from onnxruntime.training.ortmodule import DebugOptions, ORTModule

# Check if DeepSpeed is installed and meets the minimum version requirement
minimum_version = Version("0.9.0")
installed_version = Version(importlib.metadata.version("deepspeed"))

if installed_version < minimum_version:
    raise ImportError(f"DeepSpeed >= {minimum_version} is required, but {installed_version} is installed.")


class ORTPipelineModule(PipelineModule):
    """ORTPipelineModule: an ORTModule-accelerated pipeline module.

    A customized version of DeepSpeed's PipelineModule that wraps each neural network layer
    with ONNX Runtime's ORTModule. This allows the forward and backward passes to run through
    ONNX Runtime, potentially improving execution performance and efficiency.

    See the "Using ORTPipelineModule for Deepspeed Pipeline Parallel" section in
    ``docs/ORTModule_Training_Guidelines.md`` of the ONNX Runtime repository for more information.

    .. note::
        Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.

    Args:
        layers (Iterable): A sequence of layers defining the pipeline structure. Can be a ``torch.nn.Sequential`` module.
        num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
        topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism for training. Must be provided if ``num_stages`` is ``None``.
        loss_fn (callable, optional): Loss is computed as ``loss = loss_fn(outputs, label)``.
        seed_layers (bool, optional): Use a different seed for each layer. Defaults to False.
        seed_fn (callable, optional): Custom seed generating function. Defaults to a random seed generator.
        base_seed (int, optional): The starting seed. Defaults to 1234.
        partition_method (str, optional): The method by which layers are partitioned across pipeline stages. Defaults to 'parameters'.
        activation_checkpoint_interval (int, optional): The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
        activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
        checkpointable_layers (list, optional): If provided, only layers whose class names appear in this list are considered for activation checkpointing. Defaults to ``None``, which applies no additional filtering.
        debug_options (onnxruntime.training.ortmodule.DebugOptions, optional): If provided, it is used to configure ORTModule debugging options. The layer index is prepended to the ONNX file prefix so that models exported for different layers do not overwrite each other.
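
    Example (a minimal sketch; the layer sizes, ``num_stages``, and ``DebugOptions`` values
    below are illustrative only)::

        import torch.nn as nn
        from onnxruntime.training.ortmodule import DebugOptions, LogLevel

        layers = [nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)]
        model = ORTPipelineModule(
            layers=layers,
            loss_fn=nn.CrossEntropyLoss(),
            num_stages=2,
            debug_options=DebugOptions(log_level=LogLevel.INFO, save_onnx=True, onnx_prefix="pipeline_model"),
        )
        # The wrapped model is then passed to deepspeed.initialize() as usual.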
    """

    def __init__(
        self,
        layers,
        num_stages=None,
        topology=None,
        loss_fn=None,
        seed_layers=False,
        seed_fn=None,
        base_seed=1234,
        partition_method="parameters",
        activation_checkpoint_interval=0,
        activation_checkpoint_func=checkpointing.checkpoint,
        checkpointable_layers=None,
        debug_options=None,
    ):
        """
        Initialize the ORTPipelineModule with the option to include ONNX Runtime debug options.
        """

        # Keyword arguments forwarded to ORTModule when wrapping each layer; currently only debug options are supported.
        self.ort_kwargs = {"debug_options": debug_options} if debug_options is not None else {}

        super().__init__(
            layers,
            num_stages,
            topology,
            loss_fn,
            seed_layers,
            seed_fn,
            base_seed,
            partition_method,
            activation_checkpoint_interval,
            activation_checkpoint_func,
            checkpointable_layers,
        )

    def _build(self):
        """
        Build the locally owned layers, mirroring ``PipelineModule._build()``, with the one
        difference that each ``nn.Module`` layer is wrapped with ORTModule. When debug options
        are provided, the layer index is prepended to the ONNX file prefix so that models
        exported for different layers do not overwrite one another.
        """
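        # Example of the resulting naming (illustrative values): with DebugOptions(onnx_prefix="pipeline_model",
        # save_onnx=True) and locally owned layers at indices 3 and 4, the wrapped layers receive the ONNX
        # prefixes "3_pipeline_model" and "4_pipeline_model"; the exact exported file names built from these
        # prefixes are determined by ORTModule.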
        specs = self._layer_specs

        for local_idx, layer in enumerate(specs[self._local_start : self._local_stop]):
            layer_idx = local_idx + self._local_start
            if self.seed_layers:
                if self.seed_fn:
                    self.seed_fn(self.base_seed + layer_idx)
                else:
                    ds_utils.set_random_seed(self.base_seed + layer_idx)

            # Recursively build PipelineModule objects
            if isinstance(layer, PipelineModule):
                raise NotImplementedError("RECURSIVE BUILD NOT YET IMPLEMENTED")

            # TODO: Support wrapping for LayerSpec and TiedLayerSpec in addition to nn.Module in sequential.
            # Currently, we only support wrapping nn.Module instances.

            # Plain nn.Module instances are wrapped with ORTModule and registered directly.
            elif isinstance(layer, nn.Module):
                name = str(layer_idx)

                if "debug_options" in self.ort_kwargs:
                    # Prepend the layer index to the ONNX prefix so that each wrapped layer
                    # exports to distinct ONNX files.
                    new_onnx_prefix = name + "_" + self.ort_kwargs["debug_options"].onnx_prefix
                    parallel_debug_options = DebugOptions(
                        self.ort_kwargs["debug_options"].log_level,
                        self.ort_kwargs["debug_options"].save_onnx,
                        new_onnx_prefix,
                    )
                    wrapped_layer = ORTModule(layer, parallel_debug_options)
                else:
                    wrapped_layer = ORTModule(layer)

                self.forward_funcs.append(wrapped_layer)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, wrapped_layer)

            # TiedLayerSpec objects contain an nn.Module that should be allocated now.
            elif isinstance(layer, TiedLayerSpec):
                # Build and register the module if we haven't seen it before.
                if layer.key not in self.tied_modules:
                    self.tied_modules[layer.key] = layer.build()
                    self.tied_weight_attrs[layer.key] = layer.tied_weight_attr

                if layer.forward_fn is None:
                    # Just use forward()
                    self.forward_funcs.append(self.tied_modules[layer.key])
                else:
                    # User specified fn with args (module, input)
                    self.forward_funcs.append(partial(layer.forward_fn, self.tied_modules[layer.key]))

            # LayerSpec objects contain an nn.Module that should be allocated now.
            elif isinstance(layer, LayerSpec):
                module = layer.build()
                name = str(layer_idx)
                self.forward_funcs.append(module)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, module)

            # Last option: layer may be a functional (e.g., lambda). We do nothing in
            # that case and just use it in forward()
            else:
                self.forward_funcs.append(layer)

        # All pipeline parameters should be considered as model parallel in the context
        # of our FP16 optimizer
        for p in self.parameters():
            p.ds_pipe_replicated = False