Commit ee74fb6
Introducing ORTPipelineModule - DeepSpeed Parallel Pipeline Support. (microsoft#20287)
### Description

Introducing a new class, ORTPipelineModule, to handle wrapping layers in DeepSpeed pipeline parallelism.

### Motivation and Context

To support pipeline parallelism on ORTModule, this PR includes initial support for DeepSpeed pipeline parallelism.

- [x] Support pipeline parallel where layers are nn.Modules in a Sequential.
- [ ] Support LayerSpec and TiedLayerSpec
- [ ] Enable partitioning to accept List
- [ ] Full-GPU Graph Consolidation
- [ ] Subgraph Merging for Inference
1 parent f664f91 commit ee74fb6
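In practice the change is a drop-in replacement for DeepSpeed's `PipelineModule`. A minimal before/after sketch (the `layers` sequence is assumed to be defined as in any DeepSpeed pipeline script):

```python
# Before: DeepSpeed's stock pipeline wrapper.
# from deepspeed.pipe import PipelineModule
# model = PipelineModule(layers=layers, num_stages=2)

# After: the ORTModule-wrapping variant introduced in this commit.
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule

model = ORTPipelineModule(layers=layers, num_stages=2)
```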

File tree

7 files changed: +303 −0 lines changed


cmake/onnxruntime_python.cmake

Lines changed: 7 additions & 0 deletions
@@ -380,6 +380,9 @@ if (onnxruntime_ENABLE_TRAINING)
   file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
     "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
   )
+  file(GLOB onnxruntime_python_ortmodule_pipe_srcs CONFIGURE_DEPENDS
+    "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/experimental/pipe/*"
+  )
   file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
     "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
   )
@@ -756,6 +759,7 @@ if (onnxruntime_ENABLE_TRAINING)
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
+  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/pipe
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
@@ -806,6 +810,9 @@ if (onnxruntime_ENABLE_TRAINING)
   COMMAND ${CMAKE_COMMAND} -E copy
     ${onnxruntime_python_ortmodule_graph_optimizers_srcs}
     $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
+  COMMAND ${CMAKE_COMMAND} -E copy
+    ${onnxruntime_python_ortmodule_pipe_srcs}
+    $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/pipe/
   COMMAND ${CMAKE_COMMAND} -E copy
     ${onnxruntime_python_ort_triton_srcs}
     $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/

docs/ORTModule_Training_Guidelines.md

Lines changed: 28 additions & 0 deletions
@@ -495,3 +495,31 @@ for epoch in range(start_epoch, n_epochs):
 ```
 
 Check [LoadBalancingDistributedBatchSampler implementation](../orttraining/orttraining/python/training/utils/data/sampler.py) for more details.
+
+## 8 Using ORTPipelineModule for DeepSpeed Pipeline Parallel
+
+You can use `ORTPipelineModule` to support DeepSpeed pipeline parallelism. Here's how to integrate it into your pipeline:
+
+```python
+from onnxruntime.training.ortmodule import DebugOptions, LogLevel
+from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
+
+# Create a debug configuration if needed.
+# Since multiple graphs are exported here, each layer's graph is saved with the
+# layer index added as a prefix to differentiate the files.
+debug_options = DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name")
+
+# Keep your DeepSpeed script the same and use ORTPipelineModule instead of PipelineModule.
+# Initialize the ORTPipelineModule.
+pipeline_module = ORTPipelineModule(
+    layers,
+    num_stages=2,  # Set your number of stages
+    base_seed=1234,
+    partition_method="parameters",
+    debug_options=debug_options,  # Pass the debug configuration if needed
+)
+
+# Keep the rest of the script as it is.
+```
+
+Check [ORTPipelineModule implementation](../orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py) for more details.
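The `layers` argument in the snippet above is not defined there; it follows DeepSpeed's `PipelineModule` convention. A minimal sketch of what it might look like, e.g. a plain `nn.Sequential` (layer sizes are arbitrary):

```python
import torch.nn as nn

# Any sequence of nn.Modules works; ORTPipelineModule wraps each one with ORTModule.
layers = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 3),
)
```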
orttraining/orttraining/python/training/ortmodule/experimental/pipe/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from ._ort_pipeline_module import ORTPipelineModule  # noqa: F401
orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py

Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import importlib.metadata
+from functools import partial
+
+import torch.nn as nn
+from deepspeed.pipe import LayerSpec, PipelineModule, TiedLayerSpec
+from deepspeed.runtime import utils as ds_utils
+from deepspeed.runtime.activation_checkpointing import checkpointing
+from packaging.version import Version
+
+from onnxruntime.training.ortmodule import DebugOptions, ORTModule
+
+# Check that DeepSpeed is installed and meets the minimum version requirement.
+minimum_version = Version("0.9.0")
+installed_version = Version(importlib.metadata.version("deepspeed"))
+
+if installed_version < minimum_version:
+    raise ImportError(f"DeepSpeed >= {minimum_version} is required, but {installed_version} is installed.")
+
+
+class ORTPipelineModule(PipelineModule):
+    """ORTPipelineModule pipeline module.
+
+    A customized version of DeepSpeed's PipelineModule that wraps each neural network layer
+    with ONNX Runtime's ORTModule. This modification allows leveraging ONNX Runtime optimizations
+    for the forward and backward passes, potentially enhancing execution performance and efficiency.
+
+    See the "Using ORTPipelineModule for DeepSpeed Pipeline Parallel" section of
+    "docs/ORTModule_Training_Guidelines.md" in the ORT repository for more information.
+
+    .. note::
+        Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.
+
+    Args:
+        layers (Iterable): A sequence of layers defining the pipeline structure. Can be a ``torch.nn.Sequential`` module.
+        num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
+        topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism for training. Must be provided if ``num_stages`` is ``None``.
+        loss_fn (callable, optional): Loss is computed as ``loss = loss_fn(outputs, label)``.
+        seed_layers (bool, optional): Use a different seed for each layer. Defaults to False.
+        seed_fn (callable, optional): A custom seed-generating function. Defaults to a random seed generator.
+        base_seed (int, optional): The starting seed. Defaults to 1234.
+        partition_method (str, optional): The method by which layers are partitioned. Defaults to 'parameters'.
+        activation_checkpoint_interval (int, optional): The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
+        activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
+        checkpointable_layers (list, optional): Layers that are eligible for checkpointing. Defaults to None, which applies no additional filtering.
+        debug_options (onnxruntime.training.ortmodule.DebugOptions, optional): If provided, it is used to configure debugging
+            options for each wrapped ORTModule; the layer index is prepended to the ONNX prefix so exported ONNX files are not overwritten.
+    """
+
+    def __init__(
+        self,
+        layers,
+        num_stages=None,
+        topology=None,
+        loss_fn=None,
+        seed_layers=False,
+        seed_fn=None,
+        base_seed=1234,
+        partition_method="parameters",
+        activation_checkpoint_interval=0,
+        activation_checkpoint_func=checkpointing.checkpoint,
+        checkpointable_layers=None,
+        debug_options=None,
+    ):
+        """
+        Initialize the ORTPipelineModule with the option to include ONNX Runtime debug options.
+        """
+
+        self.ort_kwargs = {"debug_options": debug_options} if debug_options is not None else {}
+
+        super().__init__(
+            layers,
+            num_stages,
+            topology,
+            loss_fn,
+            seed_layers,
+            seed_fn,
+            base_seed,
+            partition_method,
+            activation_checkpoint_interval,
+            activation_checkpoint_func,
+            checkpointable_layers,
+        )
+
+    def _build(self):
+        """
+        Does the same thing as PipelineModule._build(), except that it wraps each layer with ORTModule.
+        It also handles saving ONNX models with debug options when multiple models are exported.
+        """
+        specs = self._layer_specs
+
+        for local_idx, layer in enumerate(specs[self._local_start : self._local_stop]):
+            layer_idx = local_idx + self._local_start
+            if self.seed_layers:
+                if self.seed_fn:
+                    self.seed_fn(self.base_seed + layer_idx)
+                else:
+                    ds_utils.set_random_seed(self.base_seed + layer_idx)
+
+            # Recursively build PipelineModule objects
+            if isinstance(layer, PipelineModule):
+                raise NotImplementedError("RECURSIVE BUILD NOT YET IMPLEMENTED")
+
+            # TODO: Support wrapping for LayerSpec and TiedLayerSpec in addition to nn.Module in sequential.
+            # Currently, we only support wrapping nn.Module instances.
+
+            # nn.Module layers are wrapped with ORTModule and registered now.
+            elif isinstance(layer, nn.Module):
+                name = str(layer_idx)
+
+                if "debug_options" in self.ort_kwargs:
+                    new_onnx_prefix = name + "_" + self.ort_kwargs["debug_options"].onnx_prefix
+                    parallel_debug_options = DebugOptions(
+                        self.ort_kwargs["debug_options"].log_level,
+                        self.ort_kwargs["debug_options"].save_onnx,
+                        new_onnx_prefix,
+                    )
+                    wrapped_layer = ORTModule(layer, parallel_debug_options)
+                else:
+                    wrapped_layer = ORTModule(layer)
+
+                self.forward_funcs.append(wrapped_layer)
+                self.fwd_map.update({name: len(self.forward_funcs) - 1})
+                self.add_module(name, wrapped_layer)
+
+            # TiedLayerSpec objects contain an nn.Module that should be allocated now.
+            elif isinstance(layer, TiedLayerSpec):
+                # Build and register the module if we haven't seen it before.
+                if layer.key not in self.tied_modules:
+                    self.tied_modules[layer.key] = layer.build()
+                    self.tied_weight_attrs[layer.key] = layer.tied_weight_attr
+
+                if layer.forward_fn is None:
+                    # Just use forward()
+                    self.forward_funcs.append(self.tied_modules[layer.key])
+                else:
+                    # User specified fn with args (module, input)
+                    self.forward_funcs.append(partial(layer.forward_fn, self.tied_modules[layer.key]))
+
+            # LayerSpec objects contain an nn.Module that should be allocated now.
+            elif isinstance(layer, LayerSpec):
+                module = layer.build()
+                name = str(layer_idx)
+                self.forward_funcs.append(module)
+                self.fwd_map.update({name: len(self.forward_funcs) - 1})
+                self.add_module(name, module)
+
+            # Last option: layer may be a functional (e.g., lambda). We do nothing in
+            # that case and just use it in forward()
+            else:
+                self.forward_funcs.append(layer)
+
+        # All pipeline parameters should be considered as model parallel in the context
+        # of our FP16 optimizer
+        for p in self.parameters():
+            p.ds_pipe_replicated = False
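To make the debug-option handling in `_build` concrete, here is a small illustration (values hypothetical) of the per-layer ONNX prefix it derives, which is why exported graphs do not overwrite each other:

```python
from onnxruntime.training.ortmodule import DebugOptions, LogLevel

# Hypothetical debug options, as a user might pass to ORTPipelineModule.
debug_options = DebugOptions(log_level=LogLevel.VERBOSE, save_onnx=True, onnx_prefix="model_name")

# For the layer at global index 3, _build constructs a per-layer prefix:
layer_idx = 3
new_onnx_prefix = str(layer_idx) + "_" + debug_options.onnx_prefix
print(new_onnx_prefix)  # -> "3_model_name"
```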

orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py

Lines changed: 14 additions & 0 deletions
@@ -56,6 +56,19 @@ def run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log):
     run_subprocess(command, cwd=cwd, log=log).check_returncode()
 
 
+def run_ort_pipeline_module_tests(cwd, log):
+    log.debug("Running: ORTPipelineModule tests")
+
+    command = [
+        "deepspeed",
+        "orttraining_test_ort_pipeline_module.py",
+        "--deepspeed_config",
+        "orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json",
+    ]
+
+    run_subprocess(command, cwd=cwd, log=log).check_returncode()
+
+
 def run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, data_dir):
     log.debug("Running: ORTModule fairscale sharded optimizer tests")
     command = [
@@ -94,6 +107,7 @@ def main():
     run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, args.mnist)
 
     run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log)
+    run_ort_pipeline_module_tests(cwd, log)
     run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, args.mnist)
 
     run_distributed_cache_test(cwd, log)
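`run_subprocess` is the test harness's helper; outside the harness, an equivalent invocation would be a plain subprocess call. A sketch using only the standard library (run from the test directory, with DeepSpeed installed):

```python
import subprocess

# Equivalent to what run_ort_pipeline_module_tests launches via the test harness.
subprocess.run(
    [
        "deepspeed",
        "orttraining_test_ort_pipeline_module.py",
        "--deepspeed_config",
        "orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json",
    ],
    check=True,
)
```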
orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+import argparse
+
+import deepspeed
+import torch
+from torch import nn
+
+from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
+
+# USAGE:
+# pip install deepspeed
+# deepspeed orttraining_test_ort_pipeline_module.py --deepspeed_config=orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json --pipeline-parallel-size 2 --steps=100
+# Expected output: steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694
+
+
+class SampleData(torch.utils.data.Dataset):
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __len__(self):
+        return self.x.size()[0]
+
+    def __getitem__(self, idx):
+        return self.x[idx], self.y[idx]
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--local_rank", type=int, default=-1, help="local rank passed from distributed launcher")
+    parser.add_argument("-s", "--steps", type=int, default=100, help="quit after this many steps")
+    parser.add_argument("-p", "--pipeline-parallel-size", type=int, default=2, help="pipeline parallelism")
+    parser.add_argument("--backend", type=str, default="nccl", help="distributed backend")
+    parser.add_argument("--seed", type=int, default=0, help="PRNG seed")
+    parser.add_argument("--fp16", type=bool, default=False, help="fp16 run")
+
+    parser = deepspeed.add_config_arguments(parser)
+    args = parser.parse_args()
+    return args
+
+
+n = 10
+d_in = 4
+d_hidden = 8
+d_out = 3
+args = get_args()
+torch.cuda.set_device(args.local_rank)
+device = torch.device("cuda", args.local_rank)
+
+# dist.init_process_group(backend=args.backend)
+deepspeed.init_distributed(dist_backend=args.backend)
+torch.manual_seed(args.seed)
+
+# Model.
+model = nn.Sequential(
+    nn.Linear(d_in, d_hidden),  # Stage 1
+    nn.ReLU(),  # Stage 1
+    nn.Linear(d_hidden, d_hidden),  # Stage 1
+    nn.ReLU(),  # Stage 1
+    nn.Linear(d_hidden, d_hidden),  # Stage 2
+    nn.ReLU(),  # Stage 2
+    nn.Linear(d_hidden, d_out),  # Stage 2
+)
+
+model = ORTPipelineModule(
+    layers=model,
+    loss_fn=torch.nn.CrossEntropyLoss(),
+    num_stages=args.pipeline_parallel_size,
+    partition_method="uniform",  # alternative: "parameters"
+    activation_checkpoint_interval=0,
+)
+
+params = [p for p in model.parameters() if p.requires_grad]
+
+# Input.
+x = torch.rand((n, d_in))
+if args.fp16:
+    x = x.half()
+# Output.
+y = torch.randint(0, d_out, (n,))
+ds = SampleData(x, y)
+
+print("Initialize deepspeed")
+model_engine, optimizer, _, _ = deepspeed.initialize(
+    args=args, model=model, model_parameters=params, training_data=ds
+)
+
+for step in range(args.steps):
+    loss = model_engine.train_batch()
+    if step % 10 == 0:
+        print("step = ", step, ", loss = ", loss)

setup.py

Lines changed: 1 addition & 0 deletions
@@ -486,6 +486,7 @@ def finalize_options(self):
     "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
     "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
     "onnxruntime.training.ortmodule.graph_optimizers",
+    "onnxruntime.training.ortmodule.experimental.pipe",
     "onnxruntime.training.ort_triton",
     "onnxruntime.training.ort_triton.kernel",
     "onnxruntime.training.utils",
