From 7aabc71450eed298ce146a10586336163503cae0 Mon Sep 17 00:00:00 2001 From: Hann Wang Date: Fri, 6 Jun 2025 15:26:37 +0800 Subject: [PATCH 1/6] [llama4] enable expert parallel on the same device mesh as tp (tp2ep) --- .gitignore | 1 + torchtitan/config_manager.py | 3 + .../llama4/infra/expert_parallel.py | 36 ++++ .../llama4/infra/parallelize_llama.py | 78 +++++-- torchtitan/experiments/llama4/model/moe.py | 201 ++++++++++++++---- .../llama4/train_configs/debug_model.toml | 1 + .../llama4/train_configs/llama4_17bx128e.toml | 1 + .../llama4/train_configs/llama4_17bx16e.toml | 1 + 8 files changed, 264 insertions(+), 58 deletions(-) diff --git a/.gitignore b/.gitignore index e39990d72..21e3a56b4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ build outputs dist/* .vscode +slurm-*.out # data data diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 4300c3bb8..c2a686605 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -353,6 +353,9 @@ class Parallelism: - 'alltoall' means to all-to-all shuffle the kv shards. The default value is 'allgather'. """ + + enable_tp2ep: bool = False + """Whether to use expert parallelism instead of tensor parallelism for shared experts.""" @dataclass diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index 68f2e7a75..37f4b083f 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -141,3 +141,39 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ), partial(self._prepare_output_fn, self.output_layout, self.use_local_output), ) + + +class ExpertParallel(ParallelStyle): + + def __init__(self, ): + super().__init__() + + @staticmethod + def _prepare_input_fn(mod, inputs, device_mesh): + for inp in inputs: + if isinstance(inp, torch.Tensor): + assert not isinstance( + inp, DTensor), "ExpertParallel expects local tensor inputs." + return inputs + + def _partition_fn(self, name, module, device_mesh): + # shard on the expert dimension + for name, param in module.named_parameters(recurse=False): + dist_param = nn.Parameter( + distribute_tensor(param, device_mesh, [Shard(0)])) + module.register_parameter(name, dist_param) + + @staticmethod + def _prepare_output_fn(mod, outputs, device_mesh): + assert not isinstance( + outputs, DTensor), "ExpertParallel expects local tensor outputs." 
+ return outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._partition_fn, + self._prepare_input_fn, + self._prepare_output_fn, + ) diff --git a/torchtitan/experiments/llama4/infra/parallelize_llama.py b/torchtitan/experiments/llama4/infra/parallelize_llama.py index 785d9d8a5..49b3bf48c 100644 --- a/torchtitan/experiments/llama4/infra/parallelize_llama.py +++ b/torchtitan/experiments/llama4/infra/parallelize_llama.py @@ -64,7 +64,11 @@ def parallelize_llama( enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) - apply_moe_tp(model, world_mesh["tp"]) + apply_moe_tp( + model, + world_mesh["tp"], + enable_tp2ep=job_config.parallelism.enable_tp2ep, + ) if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) @@ -145,6 +149,7 @@ def _sync_tokens_per_expert(module, *_): def apply_moe_tp( model: nn.Module, tp_mesh: DeviceMesh, + enable_tp2ep: bool = False, ): from torch.distributed.tensor import Partial, Replicate, Shard from torch.distributed.tensor.parallel import ( @@ -152,25 +157,62 @@ def apply_moe_tp( PrepareModuleInputOutput, ) - from .expert_parallel import NoParallel, TensorParallel + from .expert_parallel import ( + NoParallel, + TensorParallel, + ExpertParallel, + ) for transformer_block in model.layers.values(): - moe_layer_plan = { - # input / output sharding on the seqlen dim - # all-gather for input, reduce-scatter for output - "moe": PrepareModuleInputOutput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - use_local_input=True, - output_layouts=(Partial(),), - desired_output_layouts=(Shard(1),), - ), - # replicate computation for the router - "moe.router.gate": NoParallel(), - # input Replicate, output Partial - "moe.experts": TensorParallel(output_layout=Partial()), - "moe.shared_expert": TensorParallel(output_layout=Partial()), - } + if enable_tp2ep: + moe_layer_plan = { + # input / output sharding on the seqlen dim + "moe": + PrepareModuleInputOutput( + input_layouts=(Shard(1), ), + desired_input_layouts=(Shard(1), ), + use_local_input=True, + output_layouts=(Shard(1), ), + desired_output_layouts=(Shard(1), ), + ), + # FIXME: The input is reshaped after sharded along + # the seqlen dimension. Should we use local tensors + # instead of Replicate? + "moe.router.gate": + NoParallel(), + # Given the tokens are not splitted evenly, + # we need to use local tensors for both input / output. + # After the manual all-to-all gather, the result is + # sharded along the seqlen dim. 
+ "moe.experts": + ExpertParallel(), + "moe.shared_expert": + TensorParallel( + input_layouts=(Shard(1), None), + output_layout=Shard(1), + ), + } + else: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": + PrepareModuleInputOutput( + input_layouts=(Shard(1), ), + desired_input_layouts=(Replicate(), ), + use_local_input=True, + output_layouts=(Partial(), ), + desired_output_layouts=(Shard(1), ), + ), + # replicate computation for the router + "moe.router.gate": + NoParallel(), + # input Replicate, output Partial + "moe.experts": + TensorParallel(output_layout=Partial()), + "moe.shared_expert": + TensorParallel(output_layout=Partial()), + } parallelize_module( module=transformer_block, device_mesh=tp_mesh, diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index a07bf0f7b..faf995f54 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -7,6 +7,9 @@ import torch import torch.nn.functional as F from torch import nn +import torch.distributed as dist +from torch.distributed._functional_collectives import all_to_all_single_autograd +from torch.distributed.tensor import DTensor, Shard from .args import TransformerModelArgs @@ -31,6 +34,20 @@ def forward( x: torch.Tensor, num_local_tokens_per_expert: torch.Tensor | list[int] | None = None, ) -> torch.Tensor: + if isinstance(self.w1, DTensor) and self.w1.placements == ( + Shard(0), ) and self.w1.device_mesh.size() > 1: + # expert parallel enabled + w1 = self.w1.to_local() + w2 = self.w2.to_local() + w3 = self.w3.to_local() + experts_per_rank = self.num_experts // self.w1.device_mesh.size() + else: + # expert parallel disabled + w1 = self.w1 + w2 = self.w2 + w3 = self.w3 + experts_per_rank = self.num_experts + # TODO: keeping this for loop implementation for comparison # and readability, will remove later if not self.use_grouped_mm: @@ -44,24 +61,27 @@ def forward( ) out_experts_splits = [] for expert_idx, x_expert in enumerate(x): - w1, w2, w3 = ( - self.w1[expert_idx], - self.w2[expert_idx], - self.w3[expert_idx], + expert_idx = expert_idx % experts_per_rank + current_w1, current_w2, current_w3 = ( + w1[expert_idx], + w2[expert_idx], + w3[expert_idx], ) - h = F.silu(torch.matmul(x_expert, w1)) - h = h * torch.matmul(x_expert, w3) - h = torch.matmul(h, w2) + h = F.silu(torch.matmul(x_expert, current_w1)) + h = h * torch.matmul(x_expert, current_w3) + h = torch.matmul(h, current_w2) # h shape (tokens_per_expert(varying), dim) out_experts_splits.append(h) out = torch.cat(out_experts_splits, dim=0) else: + bs, slen, dim = x.shape + x = x.reshape(1, bs * slen, dim) # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, self.w1)) - h = h * torch.bmm(x, self.w3) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, self.w2) - + out = torch.bmm(h, w2) + out = out.reshape(bs, slen, dim) return out # grouped mm implementation @@ -76,15 +96,15 @@ def forward( assert x.dim() == 2 else: offsets = None + bs, slen, dim = x.shape # fall back to regular bmm between 3D tensors - assert x.dim() == 3 + x = x.reshape(1, bs * slen, dim) - assert ( - x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 - ), "torch._grouped_mm only supports bf16 dtypes" - h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) - h = h * torch._grouped_mm(x, self.w3, offs=offsets) - out = 
torch._grouped_mm(h, self.w2, offs=offsets) + assert (x.dtype == w1.dtype == w2.dtype == w3.dtype == + torch.bfloat16), "torch._grouped_mm only supports bf16 dtypes" + h = F.silu(torch._grouped_mm(x, w1, offs=offsets)) + h = h * torch._grouped_mm(x, w3, offs=offsets) + out = torch._grouped_mm(h, w2, offs=offsets) return out @@ -172,8 +192,14 @@ def init_weights(self, init_std: float): class MoE(nn.Module): - def __init__(self, model_args: TransformerModelArgs): + def __init__( + self, + model_args: TransformerModelArgs, + scoring_before_experts: bool = True, + ): super().__init__() + # compatibility with DeepSeek MoE + self.scoring_before_experts = scoring_before_experts dim = model_args.dim hidden_dim = 4 * model_args.dim ffn_dim_multiplier = model_args.ffn_dim_multiplier @@ -271,9 +297,70 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim=0, index=token_indices, ) - routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to( - x.dtype - ) + + ep_size = 1 + ep_group = None + if isinstance( + self.experts.w1, DTensor) and self.experts.w1.placements == ( + Shard(0), ) and self.experts.w1.device_mesh.size() > 1: + # expert parallel enabled + ep_size = self.experts.w1.device_mesh.size() + ep_group = self.experts.w1.device_mesh.get_group() + assert num_local_tokens_per_expert is not None + with torch.no_grad(): + tokens_per_expert_group = num_local_tokens_per_expert.new_empty( + num_local_tokens_per_expert.shape[0]) + dist.all_to_all_single(tokens_per_expert_group, + num_local_tokens_per_expert, + group=ep_group) + input_splits = num_local_tokens_per_expert.view(ep_size, + -1).sum(dim=1) + output_splits = tokens_per_expert_group.view(ep_size, + -1).sum(dim=1) + if self.training: + gathered_tokens = all_to_all_single_autograd( + routed_input, + output_splits.tolist(), + input_splits.tolist(), + ep_group, + ) + gathered_top_scores = all_to_all_single_autograd( + top_scores, + output_splits.tolist(), + input_splits.tolist(), + ep_group, + ) + else: + # TODO: unify with all_to_all_single_autograd after + # https://github.com/pytorch/pytorch/issues/154370 is resolved + gathered_num_tokens = output_splits.sum() + gathered_tokens = routed_input.new_empty( + (gathered_num_tokens, dim)) + dist.all_to_all_single( + gathered_tokens, + routed_input, + output_splits.tolist(), + input_splits.tolist(), + group=ep_group, + ) + gathered_top_scores = top_scores.new_empty( + gathered_num_tokens, ) + dist.all_to_all_single( + gathered_top_scores, + top_scores, + output_splits.tolist(), + input_splits.tolist(), + group=ep_group, + ) + else: + # expert parallel disabled + gathered_tokens = routed_input + gathered_top_scores = top_scores + tokens_per_expert_group = num_local_tokens_per_expert + + if self.scoring_before_experts: + gathered_tokens = (gathered_tokens.to(torch.float32) * + gathered_top_scores.reshape(-1, 1)).to(x.dtype) if self.use_grouped_mm: # NOTE: In order to use torch._grouped_mm, we need to make sure @@ -287,39 +374,73 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ALIGN_SIZE_M = 16 with torch.no_grad(): + experts_per_rank = self.experts.num_experts // ep_size ( permuted_indices, - num_local_tokens_per_expert, + tokens_per_expert_group, _, ) = generate_permute_indices( - num_local_tokens_per_expert, - self.experts.num_experts, - 1, - token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M, + tokens_per_expert_group, + experts_per_rank, + ep_size, ALIGN_SIZE_M, ) - token_indices = torch.vstack( - (token_indices, token_indices.new_zeros((dim))) - ) - 
token_indices = token_indices[permuted_indices, :] - routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) - routed_input = routed_input[permuted_indices, :] + gathered_tokens_buffer = torch.vstack( + (gathered_tokens, gathered_tokens.new_zeros((dim)))) + buffer_shape = gathered_tokens_buffer.shape + gathered_tokens = gathered_tokens_buffer[permuted_indices, :] + + gathered_top_scores = torch.cat( + (gathered_top_scores, gathered_top_scores.new_zeros(1))) + gathered_top_scores = gathered_top_scores[permuted_indices] else: # NOTE: this would incur a synchronization between device and host - num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() + if tokens_per_expert_group is not None: + tokens_per_expert_group = tokens_per_expert_group.tolist() # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_local_tokens_per_expert) + routed_output = self.experts(gathered_tokens, tokens_per_expert_group) + if not self.scoring_before_experts: + routed_output = (routed_output * gathered_top_scores.reshape(-1, 1)).to(x.dtype) + + if self.use_grouped_mm: + gathered_tokens_buffer = routed_output.new_empty(buffer_shape) + gathered_tokens_buffer[permuted_indices, :] = routed_output + routed_output = gathered_tokens_buffer[:(buffer_shape[0] - 1), :] + + if ep_size > 1: + # expert parallel enabled, we need to gather the output + # from all experts + if self.training: + returned_tokens = all_to_all_single_autograd( + routed_output, + input_splits.tolist(), + output_splits.tolist(), + ep_group, + ) + else: + # TODO: unify with all_to_all_single_autograd after + # https://github.com/pytorch/pytorch/issues/154370 is resolved + returned_tokens = routed_output.new_empty( + (input_splits.sum(), dim)) + dist.all_to_all_single( + returned_tokens, + routed_output, + input_splits.tolist(), + output_splits.tolist(), + group=ep_group, + ) + else: + # expert parallel disabled, no need to gather + returned_tokens = routed_output # shared expert if self.shared_expert is not None: - out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( - bs * slen, dim - ) + out = self.shared_expert(x).reshape(bs * slen, dim) else: - out = torch.zeros_like(x.reshape(bs * slen, dim)) + out = x.new_zeros((bs * slen, dim)) - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.scatter_add(dim=0, index=token_indices, src=returned_tokens) out = out.reshape(bs, slen, dim) return out diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index bc48d3809..a1083f507 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -52,6 +52,7 @@ tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 context_parallel_degree = 1 +enable_tp2ep = false [checkpoint] enable_checkpoint = false diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index f508968c8..1c6fd25db 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -46,6 +46,7 @@ pipeline_parallel_degree = 4 # pipeline_parallel_schedule = "interleaved1f1b" # pipeline_parallel_microbatches = 2 context_parallel_degree = 1 +enable_tp2ep = false [checkpoint] enable_checkpoint = false diff --git 
a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index c899dd508..ee76deafb 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -44,6 +44,7 @@ tensor_parallel_degree = 8 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 context_parallel_degree = 1 +enable_tp2ep = false [checkpoint] enable_checkpoint = false From 18c5d1756dd1198a409b827e19740ab3029336c0 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 9 Jun 2025 04:29:16 +0000 Subject: [PATCH 2/6] refactor: move tp2ep communications into TokenDispatcher --- .../kernels/moe/token_dispatcher.py | 127 ++++++++++++++++++ .../llama4/infra/expert_parallel.py | 44 +++++- .../llama4/infra/parallelize_llama.py | 8 +- torchtitan/experiments/llama4/model/moe.py | 118 ++++------------ 4 files changed, 198 insertions(+), 99 deletions(-) create mode 100644 torchtitan/experiments/kernels/moe/token_dispatcher.py diff --git a/torchtitan/experiments/kernels/moe/token_dispatcher.py b/torchtitan/experiments/kernels/moe/token_dispatcher.py new file mode 100644 index 000000000..c91b97eca --- /dev/null +++ b/torchtitan/experiments/kernels/moe/token_dispatcher.py @@ -0,0 +1,127 @@ +from typing import Tuple +import torch +import torch.distributed as dist +from torch.distributed._functional_collectives import all_to_all_single_autograd + + +class DefaultTokenDispatcher: + + def __init__(self, num_experts: int, ep_size: int = 1): + self.num_experts = num_experts + self.ep_size = ep_size + self.experts_per_rank = num_experts // ep_size + + def token_permutation( + self, + routed_input: torch.Tensor, + top_scores: torch.Tensor, + num_local_tokens_per_expert: torch.Tensor, + training: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, + torch.Tensor | None]: + return routed_input, top_scores, num_local_tokens_per_expert, None, None + + def token_unpermutation( + self, + routed_output: torch.Tensor, + input_splits: torch.Tensor | None = None, + output_splits: torch.Tensor | None = None, + training: bool = True, + ) -> torch.Tensor: + return routed_output + + +class TorchAllToAllTokenDispatcher(DefaultTokenDispatcher): + + def __init__( + self, + num_experts: int, + ep_size: int, + ep_group: torch.distributed.ProcessGroup, + ): + super().__init__(num_experts, ep_size) + self.ep_group = ep_group + + def token_permutation( + self, + routed_input: torch.Tensor, + top_scores: torch.Tensor, + num_local_tokens_per_expert: torch.Tensor, + training: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, + torch.Tensor | None]: + dim = routed_input.shape[-1] + with torch.no_grad(): + tokens_per_expert_group = num_local_tokens_per_expert.new_empty( + num_local_tokens_per_expert.shape[0]) + dist.all_to_all_single(tokens_per_expert_group, + num_local_tokens_per_expert, + group=self.ep_group) + input_splits = num_local_tokens_per_expert.view( + self.ep_size, -1).sum(dim=1) + output_splits = tokens_per_expert_group.view( + self.ep_size, -1).sum(dim=1) + if training: + gathered_tokens = all_to_all_single_autograd( + routed_input, + output_splits.tolist(), + input_splits.tolist(), + self.ep_group, + ) + gathered_top_scores = all_to_all_single_autograd( + top_scores, + output_splits.tolist(), + input_splits.tolist(), + self.ep_group, + ) + else: + # TODO: unify with all_to_all_single_autograd after + # 
https://github.com/pytorch/pytorch/issues/154370 is resolved + gathered_num_tokens = output_splits.sum() + gathered_tokens = routed_input.new_empty( + (gathered_num_tokens, dim)) + dist.all_to_all_single( + gathered_tokens, + routed_input, + output_splits.tolist(), + input_splits.tolist(), + group=self.ep_group, + ) + gathered_top_scores = top_scores.new_empty(gathered_num_tokens, ) + dist.all_to_all_single( + gathered_top_scores, + top_scores, + output_splits.tolist(), + input_splits.tolist(), + group=self.ep_group, + ) + return gathered_tokens, gathered_top_scores, tokens_per_expert_group, input_splits, output_splits + + def token_unpermutation( + self, + routed_output: torch.Tensor, + input_splits: torch.Tensor | None = None, + output_splits: torch.Tensor | None = None, + training: bool = True, + ) -> torch.Tensor: + dim = routed_output.shape[-1] + if training: + returned_tokens = all_to_all_single_autograd( + routed_output, + input_splits.tolist(), + output_splits.tolist(), + self.ep_group, + ) + else: + # TODO: unify with all_to_all_single_autograd after + # https://github.com/pytorch/pytorch/issues/154370 is resolved + returned_tokens = routed_output.new_empty( + (input_splits.sum(), dim)) + dist.all_to_all_single( + returned_tokens, + routed_output, + input_splits.tolist(), + output_splits.tolist(), + group=self.ep_group, + ) + return returned_tokens diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index 37f4b083f..489709826 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -18,9 +18,17 @@ Replicate, Shard, ) -from torch.distributed.tensor.parallel import ParallelStyle +from torch.distributed.tensor.parallel import ( + ParallelStyle, + PrepareModuleInputOutput, +) from torch.distributed.tensor.placement_types import Placement +from torchtitan.experiments.kernels.moe.token_dispatcher import ( + DefaultTokenDispatcher, + TorchAllToAllTokenDispatcher, +) + # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): @@ -156,7 +164,7 @@ def _prepare_input_fn(mod, inputs, device_mesh): inp, DTensor), "ExpertParallel expects local tensor inputs." 
return inputs - def _partition_fn(self, name, module, device_mesh): + def _partition_fn(self, name, module, device_mesh: DeviceMesh): # shard on the expert dimension for name, param in module.named_parameters(recurse=False): dist_param = nn.Parameter( @@ -177,3 +185,35 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: self._prepare_input_fn, self._prepare_output_fn, ) + + +class PrepareModuleInputOutputWithParams(PrepareModuleInputOutput): + + def __init__(self, *args, **kwargs): + self.enable_tp2ep = kwargs.pop("enable_tp2ep", False) + super().__init__(*args, **kwargs) + + def _partition_fn( + self, + name, + module, + device_mesh, + ): + for name, param in module.named_parameters(recurse=False): + dist_param = nn.Parameter( + distribute_tensor(param, device_mesh, [Replicate()])) + module.register_parameter(name, dist_param) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + if hasattr(module, "token_dispatcher") and isinstance( + module.token_dispatcher, DefaultTokenDispatcher): + module.token_dispatcher = TorchAllToAllTokenDispatcher( + num_experts=module.num_experts, + ep_size=device_mesh.size(), + ep_group=device_mesh.get_group(), + ) + + super()._apply(module, device_mesh) + self._partition_fn("", module, device_mesh) + + return module diff --git a/torchtitan/experiments/llama4/infra/parallelize_llama.py b/torchtitan/experiments/llama4/infra/parallelize_llama.py index 49b3bf48c..f952bcccc 100644 --- a/torchtitan/experiments/llama4/infra/parallelize_llama.py +++ b/torchtitan/experiments/llama4/infra/parallelize_llama.py @@ -154,13 +154,13 @@ def apply_moe_tp( from torch.distributed.tensor import Partial, Replicate, Shard from torch.distributed.tensor.parallel import ( parallelize_module, - PrepareModuleInputOutput, ) from .expert_parallel import ( NoParallel, TensorParallel, ExpertParallel, + PrepareModuleInputOutputWithParams, ) for transformer_block in model.layers.values(): @@ -168,12 +168,13 @@ def apply_moe_tp( moe_layer_plan = { # input / output sharding on the seqlen dim "moe": - PrepareModuleInputOutput( + PrepareModuleInputOutputWithParams( input_layouts=(Shard(1), ), desired_input_layouts=(Shard(1), ), use_local_input=True, output_layouts=(Shard(1), ), desired_output_layouts=(Shard(1), ), + enable_tp2ep=enable_tp2ep, ), # FIXME: The input is reshaped after sharded along # the seqlen dimension. 
Should we use local tensors @@ -197,12 +198,13 @@ def apply_moe_tp( # input / output sharding on the seqlen dim # all-gather for input, reduce-scatter for output "moe": - PrepareModuleInputOutput( + PrepareModuleInputOutputWithParams( input_layouts=(Shard(1), ), desired_input_layouts=(Replicate(), ), use_local_input=True, output_layouts=(Partial(), ), desired_output_layouts=(Shard(1), ), + enable_tp2ep=enable_tp2ep, ), # replicate computation for the router "moe.router.gate": diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index faf995f54..d02689c28 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -7,9 +7,8 @@ import torch import torch.nn.functional as F from torch import nn -import torch.distributed as dist -from torch.distributed._functional_collectives import all_to_all_single_autograd from torch.distributed.tensor import DTensor, Shard +from torchtitan.experiments.kernels.moe.token_dispatcher import DefaultTokenDispatcher from .args import TransformerModelArgs @@ -207,7 +206,7 @@ def __init__( if ffn_dim_multiplier is not None: hidden_dim = int(ffn_dim_multiplier * hidden_dim) - num_experts = model_args.num_experts + self.num_experts = model_args.num_experts hidden_dim_denom = 1 if model_args.auto_scale_hidden_dim: @@ -221,11 +220,11 @@ def __init__( self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, - num_experts=num_experts, + num_experts=self.num_experts, use_grouped_mm=self.use_grouped_mm, ) self.router = TokenChoiceTopKRouter( - dim=dim, num_experts=num_experts, top_k=model_args.top_k + dim=dim, num_experts=self.num_experts, top_k=model_args.top_k ) self.shared_expert = ( GroupedExperts( @@ -238,18 +237,20 @@ def __init__( else None ) + self.token_dispatcher = DefaultTokenDispatcher(self.num_experts) + # auxiliary-loss-free load balancing self.load_balance_coeff = model_args.load_balance_coeff # the fields below are defined even when load_balance_coeff is None # to make initialization and checkpointing code simpler self.register_buffer( "expert_bias", - torch.zeros(num_experts, dtype=torch.float32), + torch.zeros(self.num_experts, dtype=torch.float32), persistent=True, ) self.register_buffer( "tokens_per_expert", - torch.zeros(num_experts, dtype=torch.float32), + torch.zeros(self.num_experts, dtype=torch.float32), persistent=True, ) @@ -298,65 +299,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: index=token_indices, ) - ep_size = 1 - ep_group = None - if isinstance( - self.experts.w1, DTensor) and self.experts.w1.placements == ( - Shard(0), ) and self.experts.w1.device_mesh.size() > 1: - # expert parallel enabled - ep_size = self.experts.w1.device_mesh.size() - ep_group = self.experts.w1.device_mesh.get_group() - assert num_local_tokens_per_expert is not None - with torch.no_grad(): - tokens_per_expert_group = num_local_tokens_per_expert.new_empty( - num_local_tokens_per_expert.shape[0]) - dist.all_to_all_single(tokens_per_expert_group, - num_local_tokens_per_expert, - group=ep_group) - input_splits = num_local_tokens_per_expert.view(ep_size, - -1).sum(dim=1) - output_splits = tokens_per_expert_group.view(ep_size, - -1).sum(dim=1) - if self.training: - gathered_tokens = all_to_all_single_autograd( - routed_input, - output_splits.tolist(), - input_splits.tolist(), - ep_group, - ) - gathered_top_scores = all_to_all_single_autograd( - top_scores, - output_splits.tolist(), - input_splits.tolist(), - ep_group, - ) - else: - # TODO: unify with 
all_to_all_single_autograd after - # https://github.com/pytorch/pytorch/issues/154370 is resolved - gathered_num_tokens = output_splits.sum() - gathered_tokens = routed_input.new_empty( - (gathered_num_tokens, dim)) - dist.all_to_all_single( - gathered_tokens, - routed_input, - output_splits.tolist(), - input_splits.tolist(), - group=ep_group, - ) - gathered_top_scores = top_scores.new_empty( - gathered_num_tokens, ) - dist.all_to_all_single( - gathered_top_scores, - top_scores, - output_splits.tolist(), - input_splits.tolist(), - group=ep_group, - ) - else: - # expert parallel disabled - gathered_tokens = routed_input - gathered_top_scores = top_scores - tokens_per_expert_group = num_local_tokens_per_expert + ( + gathered_tokens, + gathered_top_scores, + tokens_per_expert_group, + input_splits, + output_splits, + ) = self.token_dispatcher.token_permutation( + routed_input, + top_scores, + num_local_tokens_per_expert, + self.training, + ) if self.scoring_before_experts: gathered_tokens = (gathered_tokens.to(torch.float32) * @@ -374,15 +328,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ALIGN_SIZE_M = 16 with torch.no_grad(): - experts_per_rank = self.experts.num_experts // ep_size ( permuted_indices, tokens_per_expert_group, _, ) = generate_permute_indices( tokens_per_expert_group, - experts_per_rank, - ep_size, + self.token_dispatcher.experts_per_rank, + self.token_dispatcher.ep_size, ALIGN_SIZE_M, ) gathered_tokens_buffer = torch.vstack( @@ -408,31 +361,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: gathered_tokens_buffer[permuted_indices, :] = routed_output routed_output = gathered_tokens_buffer[:(buffer_shape[0] - 1), :] - if ep_size > 1: - # expert parallel enabled, we need to gather the output - # from all experts - if self.training: - returned_tokens = all_to_all_single_autograd( - routed_output, - input_splits.tolist(), - output_splits.tolist(), - ep_group, - ) - else: - # TODO: unify with all_to_all_single_autograd after - # https://github.com/pytorch/pytorch/issues/154370 is resolved - returned_tokens = routed_output.new_empty( - (input_splits.sum(), dim)) - dist.all_to_all_single( - returned_tokens, - routed_output, - input_splits.tolist(), - output_splits.tolist(), - group=ep_group, - ) - else: - # expert parallel disabled, no need to gather - returned_tokens = routed_output + returned_tokens = self.token_dispatcher.token_unpermutation( + routed_output, input_splits, output_splits, self.training) # shared expert if self.shared_expert is not None: From 864ee31e31911378e47fadaf1b4297d223b75fce Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 9 Jun 2025 07:14:29 +0000 Subject: [PATCH 3/6] fix: torch.compile failure of TorchAllToAllTokenDispatcher --- .../llama4/infra/expert_parallel.py | 37 ------------------- .../llama4/infra/parallelize_llama.py | 8 ++-- torchtitan/experiments/llama4/model/moe.py | 15 +++++++- 3 files changed, 17 insertions(+), 43 deletions(-) diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index 489709826..64023b3d1 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -24,11 +24,6 @@ ) from torch.distributed.tensor.placement_types import Placement -from torchtitan.experiments.kernels.moe.token_dispatcher import ( - DefaultTokenDispatcher, - TorchAllToAllTokenDispatcher, -) - # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): @@ 
-185,35 +180,3 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: self._prepare_input_fn, self._prepare_output_fn, ) - - -class PrepareModuleInputOutputWithParams(PrepareModuleInputOutput): - - def __init__(self, *args, **kwargs): - self.enable_tp2ep = kwargs.pop("enable_tp2ep", False) - super().__init__(*args, **kwargs) - - def _partition_fn( - self, - name, - module, - device_mesh, - ): - for name, param in module.named_parameters(recurse=False): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Replicate()])) - module.register_parameter(name, dist_param) - - def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - if hasattr(module, "token_dispatcher") and isinstance( - module.token_dispatcher, DefaultTokenDispatcher): - module.token_dispatcher = TorchAllToAllTokenDispatcher( - num_experts=module.num_experts, - ep_size=device_mesh.size(), - ep_group=device_mesh.get_group(), - ) - - super()._apply(module, device_mesh) - self._partition_fn("", module, device_mesh) - - return module diff --git a/torchtitan/experiments/llama4/infra/parallelize_llama.py b/torchtitan/experiments/llama4/infra/parallelize_llama.py index f952bcccc..49b3bf48c 100644 --- a/torchtitan/experiments/llama4/infra/parallelize_llama.py +++ b/torchtitan/experiments/llama4/infra/parallelize_llama.py @@ -154,13 +154,13 @@ def apply_moe_tp( from torch.distributed.tensor import Partial, Replicate, Shard from torch.distributed.tensor.parallel import ( parallelize_module, + PrepareModuleInputOutput, ) from .expert_parallel import ( NoParallel, TensorParallel, ExpertParallel, - PrepareModuleInputOutputWithParams, ) for transformer_block in model.layers.values(): @@ -168,13 +168,12 @@ def apply_moe_tp( moe_layer_plan = { # input / output sharding on the seqlen dim "moe": - PrepareModuleInputOutputWithParams( + PrepareModuleInputOutput( input_layouts=(Shard(1), ), desired_input_layouts=(Shard(1), ), use_local_input=True, output_layouts=(Shard(1), ), desired_output_layouts=(Shard(1), ), - enable_tp2ep=enable_tp2ep, ), # FIXME: The input is reshaped after sharded along # the seqlen dimension. Should we use local tensors @@ -198,13 +197,12 @@ def apply_moe_tp( # input / output sharding on the seqlen dim # all-gather for input, reduce-scatter for output "moe": - PrepareModuleInputOutputWithParams( + PrepareModuleInputOutput( input_layouts=(Shard(1), ), desired_input_layouts=(Replicate(), ), use_local_input=True, output_layouts=(Partial(), ), desired_output_layouts=(Shard(1), ), - enable_tp2ep=enable_tp2ep, ), # replicate computation for the router "moe.router.gate": diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index d02689c28..0216b4344 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -8,7 +8,10 @@ import torch.nn.functional as F from torch import nn from torch.distributed.tensor import DTensor, Shard -from torchtitan.experiments.kernels.moe.token_dispatcher import DefaultTokenDispatcher +from torchtitan.experiments.kernels.moe.token_dispatcher import ( + DefaultTokenDispatcher, + TorchAllToAllTokenDispatcher, +) from .args import TransformerModelArgs @@ -299,6 +302,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: index=token_indices, ) + # TODO: Find a better place to initialize the token dispatcher. 
+ # I tried putting it in PrepareModuleInputOutputWithParams._apply, + # but caused torch compiling isses + if (isinstance(self.experts.w1, DTensor) and self.experts.w1.placements == (Shard(0),)): + self.token_dispatcher = TorchAllToAllTokenDispatcher( + num_experts=self.num_experts, + ep_size=self.experts.w1.device_mesh.size(), + ep_group=self.experts.w1.device_mesh.get_group(), + ) + ( gathered_tokens, gathered_top_scores, From b87aa1e9a5de88615f2d9cad1073890659c87586 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 13 Jun 2025 04:55:48 +0000 Subject: [PATCH 4/6] fix: expert bias update --- .../kernels/moe/token_dispatcher.py | 1 + torchtitan/experiments/llama4/model/moe.py | 51 ++++++++++++------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/torchtitan/experiments/kernels/moe/token_dispatcher.py b/torchtitan/experiments/kernels/moe/token_dispatcher.py index c91b97eca..289f3107b 100644 --- a/torchtitan/experiments/kernels/moe/token_dispatcher.py +++ b/torchtitan/experiments/kernels/moe/token_dispatcher.py @@ -10,6 +10,7 @@ def __init__(self, num_experts: int, ep_size: int = 1): self.num_experts = num_experts self.ep_size = ep_size self.experts_per_rank = num_experts // ep_size + self.ep_group = None def token_permutation( self, diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index 0216b4344..ce1b398ca 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -244,6 +244,7 @@ def __init__( # auxiliary-loss-free load balancing self.load_balance_coeff = model_args.load_balance_coeff + self.expert_bias_enabled = self.load_balance_coeff is not None and self.load_balance_coeff > 0 # the fields below are defined even when load_balance_coeff is None # to make initialization and checkpointing code simpler self.register_buffer( @@ -259,17 +260,18 @@ def __init__( # NOTE: forward hook, forward pre hook, or backward pre hook # would conflict with activation checkpointing - if self.load_balance_coeff is not None and self.load_balance_coeff > 0: + if self.expert_bias_enabled: self.register_full_backward_hook(self._update_expert_bias) def _update_expert_bias(self, *_): - expert_bias_delta = self.load_balance_coeff * torch.sign( - self.tokens_per_expert.mean() - self.tokens_per_expert - ) - expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() - self.expert_bias.add_(expert_bias_delta) + with torch.no_grad(): + expert_bias_delta = self.load_balance_coeff * torch.sign( + self.tokens_per_expert.mean() - self.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + self.expert_bias.add_(expert_bias_delta) - self.tokens_per_expert.zero_() + self.tokens_per_expert.zero_() def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -289,8 +291,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_local_tokens_per_expert, ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - # will be used to update the expert bias for load balancing - self.tokens_per_expert += num_local_tokens_per_expert + # TODO: Find a better place to initialize the token dispatcher. 
+ # I tried putting it in PrepareModuleInputOutputWithParams._apply, + # but caused torch compiling issues + if (isinstance(self.experts.w1, DTensor) + and self.experts.w1.placements == (Shard(0), )): + self.token_dispatcher = TorchAllToAllTokenDispatcher( + num_experts=self.num_experts, + ep_size=self.experts.w1.device_mesh.size(), + ep_group=self.experts.w1.device_mesh.get_group(), + ) + + # Prevent extra local tokens accumulation on evaluation or activation recomputation + if self.expert_bias_enabled and torch.is_grad_enabled(): + with torch.no_grad(): + num_local_tokens_per_expert_detached = num_local_tokens_per_expert.detach().clone() + if self.token_dispatcher.ep_group is not None: + # sum all num_local_tokens_per_expert from ep_mesh + torch.distributed.all_reduce( + num_local_tokens_per_expert_detached, + group=self.token_dispatcher.ep_group, + ) + # will be used to update the expert bias for load balancing + self.tokens_per_expert += num_local_tokens_per_expert_detached # shape (bs*slen*top_k, dim) token_indices = token_indices.reshape(-1, 1).expand(-1, dim) @@ -302,16 +325,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: index=token_indices, ) - # TODO: Find a better place to initialize the token dispatcher. - # I tried putting it in PrepareModuleInputOutputWithParams._apply, - # but caused torch compiling isses - if (isinstance(self.experts.w1, DTensor) and self.experts.w1.placements == (Shard(0),)): - self.token_dispatcher = TorchAllToAllTokenDispatcher( - num_experts=self.num_experts, - ep_size=self.experts.w1.device_mesh.size(), - ep_group=self.experts.w1.device_mesh.get_group(), - ) - ( gathered_tokens, gathered_top_scores, From cc7a45c2afed779d0c711e3f858f1f76fe923fd8 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 13 Jun 2025 06:27:50 +0000 Subject: [PATCH 5/6] chore: in-place add --- torchtitan/experiments/llama4/model/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index ce1b398ca..93ab0874e 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -313,7 +313,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.token_dispatcher.ep_group, ) # will be used to update the expert bias for load balancing - self.tokens_per_expert += num_local_tokens_per_expert_detached + self.tokens_per_expert.add_(num_local_tokens_per_expert_detached) # shape (bs*slen*top_k, dim) token_indices = token_indices.reshape(-1, 1).expand(-1, dim) From cd1680f07dc5d28f7e2c78cc35411c093e698d25 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 16 Jun 2025 06:16:23 +0000 Subject: [PATCH 6/6] fix: multiply scores in FP32 datatype --- torchtitan/experiments/llama4/model/moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index 93ab0874e..875293168 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -380,7 +380,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # shape (bs*slen*top_k, dim) routed_output = self.experts(gathered_tokens, tokens_per_expert_group) if not self.scoring_before_experts: - routed_output = (routed_output * gathered_top_scores.reshape(-1, 1)).to(x.dtype) + routed_output = (routed_output.to(torch.float32) * + gathered_top_scores.reshape(-1, 1)).to(x.dtype) if self.use_grouped_mm: gathered_tokens_buffer = 
routed_output.new_empty(buffer_shape)
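
Note on the tp2ep dispatch introduced above: the core of `TorchAllToAllTokenDispatcher.token_permutation` is an uneven all-to-all driven by per-expert token counts. The following is a minimal, illustrative sketch of that exchange (names, shapes, and the standalone function are assumptions for illustration, not the patch's exact code); it assumes `routed_input` is already sorted by global expert id, as the router guarantees upstream.

import torch
import torch.distributed as dist


def dispatch_tokens(routed_input, tokens_per_expert, ep_group):
    """All-to-all routed tokens to the EP ranks that own their experts.

    routed_input:      (num_routed_tokens, dim), grouped/sorted by global expert id
    tokens_per_expert: (num_experts,) int tensor of local token counts per expert
    """
    ep_size = dist.get_world_size(ep_group)
    dim = routed_input.shape[-1]

    # Exchange per-expert counts so each rank learns how many tokens it will
    # receive for the experts it owns (num_experts must divide evenly by ep_size).
    tokens_per_expert_group = torch.empty_like(tokens_per_expert)
    dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=ep_group)

    # Collapse counts into per-rank split sizes: how many tokens this rank sends
    # to, and receives from, each EP rank.
    input_splits = tokens_per_expert.view(ep_size, -1).sum(dim=1).tolist()
    output_splits = tokens_per_expert_group.view(ep_size, -1).sum(dim=1).tolist()

    # Uneven all-to-all of the token activations themselves.
    gathered_tokens = routed_input.new_empty((sum(output_splits), dim))
    dist.all_to_all_single(
        gathered_tokens, routed_input, output_splits, input_splits, group=ep_group
    )
    return gathered_tokens, tokens_per_expert_group, input_splits, output_splits

The unpermutation after the experts run is the same collective with `input_splits` and `output_splits` swapped, returning each token to the rank that originally routed it; the patch additionally routes through `all_to_all_single_autograd` during training so gradients flow back across the exchange.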