# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the DeepSeek-V3 model.

import copy

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    PipelineScheduleSingle,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from ..model.args import DeepSeekV3ModelArgs


def generate_module_names_per_stage(
    num_stages: int,
    num_layers: int,
    input_weight: int = 1,
    output_weight: int = 1,
) -> list[list[str]]:
| 37 | + """ |
| 38 | + Programmatically generates module names per stage for pipeline parallelism with weighting. |
| 39 | +
|
| 40 | + Args: |
| 41 | + num_stages: Number of pipeline stages |
| 42 | + num_layers: Total number of transformer layers in the model |
| 43 | + input_weight: Weight for input modules (tok_embeddings) in layer calculation |
| 44 | + output_weight: Weight for output modules (norm + output) in layer calculation |
| 45 | +
|
| 46 | + Returns: |
| 47 | + List of lists containing module names for each stage |
| 48 | +
|
| 49 | + Example: |
| 50 | + generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2) |
| 51 | + treats embeddings as 2 layers and norm+output as 2 layers for distribution |
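        and yields [["tok_embeddings", "layers.0", "layers.1"], ["layers.2", "norm", "output"]]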
| 52 | + """ |
    if num_stages < 1:
        raise ValueError("Number of stages must be at least 1")

    if num_stages == 1:
        # Single stage gets everything
        layer_names = [f"layers.{i}" for i in range(num_layers)]
        return [["tok_embeddings"] + layer_names + ["norm", "output"]]

    # Calculate effective layers including weights
    num_effective_layers = num_layers + input_weight + output_weight

    if num_stages > num_effective_layers:
        raise ValueError(
            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
        )

    # Calculate layers per stage (distribute evenly)
    layers_per_stage = num_effective_layers // num_stages
    extra_layers = num_effective_layers % num_stages

    # Ensure each stage gets at least the weight of input/output modules
    if layers_per_stage < max(input_weight, output_weight):
        raise ValueError(
            f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
        )

    module_names_per_stage = []
    current_layer = 0

    for stage_idx in range(num_stages):
        stage_modules = []

        # Calculate effective layers for this stage
        effective_layers_for_stage = layers_per_stage
        if stage_idx < extra_layers:
            effective_layers_for_stage += 1

        # First stage: handle input modules with weighting
        if stage_idx == 0:
            stage_modules.append("tok_embeddings")
            # Account for input weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - input_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        # Last stage: handle output modules with weighting
        elif stage_idx == num_stages - 1:
            # Account for output weight in layer distribution
            remaining_layers_for_stage = effective_layers_for_stage - output_weight

            # Add transformer layers
            for _ in range(remaining_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

            # Add output modules
            stage_modules.extend(["norm", "output"])

        # Middle stages: only transformer layers
        else:
            for _ in range(effective_layers_for_stage):
                if current_layer < num_layers:
                    stage_modules.append(f"layers.{current_layer}")
                    current_layer += 1

        module_names_per_stage.append(stage_modules)

    return module_names_per_stage


def pipeline_deepseekv3(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: DeepSeekV3ModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
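    """
    Splits the DeepSeek-V3 model into pipeline stages, applies the SPMD-style
    parallelisms in parallelize_fn to each stage, and builds the pipeline schedule.

    Returns:
        The pipeline schedule, the model chunks owned by this rank, and flags
        indicating whether this rank holds the first and/or the last stage.
    """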
    pp_mesh = world_mesh["pp"]

    # Determine the number of virtual stages based on schedule type
    schedule_class = get_schedule_class(
        job_config.parallelism.pipeline_parallel_schedule
    )
    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)

    # For multi-stage schedules, default is 2 virtual stages per rank
    # For single-stage schedules, default is 1 virtual stage per rank
    stages_per_rank = 1 if is_single_stage_schedule else 2
    num_virtual_stages = parallel_dims.pp * stages_per_rank
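    # e.g. parallel_dims.pp == 4 gives 8 virtual stages for a multi-stage (looped)
    # schedule and 4 virtual stages for a single-stage schedule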

    # Generate module names per stage programmatically with weighting
    num_layers = model_config.n_layers

    # You can adjust these weights based on the computational cost of embeddings and output layers
    # Higher weights mean these modules are treated as "heavier" in the distribution
    input_weight = 1  # Weight for tok_embeddings
    output_weight = 1  # Weight for norm + output layers

    module_names_per_stage = generate_module_names_per_stage(
        num_virtual_stages, num_layers, input_weight, output_weight
    )
    for i, stage_ms in enumerate(module_names_per_stage):
        logger.info(f"Stage {i}: {stage_ms}")

    stages, model_parts = pipeline_module_split(
        model,
        pp_mesh,
        job_config.parallelism.pipeline_parallel_schedule,
        device,
        module_names_per_stage,
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_module_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    pp_schedule: str,
    device: DeviceType,
    module_names_per_stage: list[list[str]],
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API creates pipeline stages based on specified module names for each stage.

    Args:
        whole_model: The complete model to be split
        pp_mesh: Pipeline parallel device mesh
        pp_schedule: Name of pipeline parallelism schedule
        device: Device type
        module_names_per_stage: List of lists, where each inner list contains the module names
                                that should be included in that stage. Module names should be
                                dot-separated paths. Examples:
                                - "tok_embeddings" for token embeddings
                                - "layers.0", "layers.1" for specific transformer layers
                                - "norm" for the final normalization layer
                                - "output" for the output projection layer

    Returns:
        Tuple of (stages, models) where stages are PipelineStage objects and models are the
        corresponding model chunks

    Example usage:
        module_names_per_stage = [
            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
            ["layers.1", "layers.2"],        # Stage 1: middle layers
            ["norm", "output"],              # Stage 2: final norm + output
        ]
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    def _build_stage_from_modules(
        stage_idx: int, module_names: list[str], num_stages: int
    ) -> tuple[PipelineStage, nn.Module]:
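        # Each stage chunk starts as a deep copy of the whole model; modules that
        # are not assigned to this stage are pruned away or replaced below.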
        model = copy.deepcopy(whole_model)

        # Create a set of modules to keep for faster lookup
        modules_to_keep = set(module_names)
| 242 | + print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}") |
        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
                layers_to_keep = {
                    name.split(".", 1)[1]
                    for name in modules_to_keep
                    if name.startswith(f"{module_name}.")
                }
                if layers_to_keep:
                    # Keep only specified layers
                    if isinstance(module_value, nn.ModuleDict):
                        for layer_name in list(module_value.keys()):
                            if layer_name not in layers_to_keep:
                                del module_value[layer_name]
                    elif isinstance(module_value, nn.ModuleList):
                        indices_to_keep = {
                            int(idx) for idx in layers_to_keep if idx.isdigit()
                        }
                        new_layers = nn.ModuleList(
                            [
                                layer
                                for i, layer in enumerate(module_value)
                                if i in indices_to_keep
                            ]
                        )
                        setattr(model, module_name, new_layers)
                else:
                    # No layers from this structure needed, set to empty structure
                    if isinstance(module_value, nn.ModuleDict):
                        setattr(model, module_name, nn.ModuleDict())
                    elif isinstance(module_value, nn.ModuleList):
                        setattr(model, module_name, nn.ModuleList())
            # Handle simple module attributes (e.g., "linear", "norm")
            elif module_name not in modules_to_keep:
                # Replace with identity module instead of None
                setattr(model, module_name, nn.Identity())

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(module_names_per_stage)
    stages = []
    models = []

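    # ZBV ("zero bubble, v-shaped") schedules need the "v" stage-to-rank layout;
    # other schedules use the looped (round-robin) layout from stage_ids_this_rank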
    schedule_class = get_schedule_class(pp_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        module_names = module_names_per_stage[stage_idx]
        stage, model_chunk = _build_stage_from_modules(
            stage_idx,
            module_names,
            num_stages,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx} "
            f"with modules {module_names}"
        )
        stages.append(stage)
        models.append(model_chunk)

    return stages, models