
Support ulysses for llm/mllm,dpo/sft #4085


Merged: 21 commits, May 5, 2025
fix
tastelikefeet committed May 2, 2025
commit 865ab193b14af40d99b890482b2346bac3030e17
125 changes: 123 additions & 2 deletions swift/trainers/sequence_parallel/ulysses.py
@@ -1,4 +1,4 @@
from typing import Optional, Any
from typing import Optional, Any, Tuple

import torch
import torch.distributed as dist
@@ -11,6 +11,127 @@
from swift.utils import get_dist_setting


def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
    """
    This function generates the parameters required for `permute` and `reshape` operations,
    which are used to process data before and after `all2all` communication.
    """
    if batch_dim_idx == 0:
        if scatter_idx < 2:
            bs, global_seq_len, num_local_head, head_dim = input.shape
            pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
            pre_all2all_permute_idx = (1, 0, 2, 3, 4)

            post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
        else:
            bs, local_seq_len, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
            pre_all2all_permute_idx = (2, 0, 1, 3, 4)

            post_all2all_permute_idx = (1, 0, 2, 3, 4)
            post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
    else:
        if scatter_idx < 2:
            global_seq_len, bs, num_local_head, head_dim = input.shape
            pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
            pre_all2all_permute_idx = None

            post_all2all_permute_idx = (1, 2, 0, 3, 4)
            post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
        else:
            local_seq_len, bs, num_total_head, head_dim = input.shape
            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
            pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
            post_all2all_permute_idx = None
            post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]

    return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
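
# A minimal shape walkthrough (illustrative only, not taken from this PR), assuming
# batch-first tensors (batch_dim_idx == 0) and a sequence-parallel world size p:
#   scatter_idx >= 2 (scatter heads, gather sequence), input [bs, s/p, h, d]:
#     pre-all2all reshape/permute -> [p, bs, s/p, h/p, d]
#     post-all2all result shape   -> [bs, s, h/p, d]
#   scatter_idx < 2 (scatter sequence, gather heads), input [bs, s, h/p, d]:
#     pre-all2all reshape/permute -> [p, bs, s/p, h/p, d]
#     post-all2all result shape   -> [bs, s/p, h, d]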


def post_all2all(permute_idx, res_shape):
    """
    Post-processing function for `all2all` communication.
    """

    def post_func(input):
        if permute_idx is not None:
            input = input.permute(permute_idx).contiguous()
        output = input.reshape(res_shape).contiguous()

        return output

    return post_func


def pre_all2all_fun(permute_idx, inp_shape, input):
    """
    Pre-processing function for `all2all` communication.
    """
    input_t = input.reshape(inp_shape).contiguous()
    if permute_idx is not None:
        input_t = input_t.permute(permute_idx).contiguous()
    return input_t
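
# A single-process sanity check for the helpers above (illustrative only; the sizes
# below are hypothetical and this function is not part of the PR). It exercises the
# reshape/permute logic without performing any actual all2all communication.
def _layout_sanity_check():
    p, bs, s, h, d = 4, 2, 16, 8, 64  # world size, batch, global seq, heads, head dim
    x = torch.randn(bs, s // p, h, d)  # a local sequence shard with all heads
    pre_perm, pre_shape, post_perm, post_shape = _generate_layout_params(2, 0, p, x)
    x_t = pre_all2all_fun(pre_perm, pre_shape, x)
    # Laid out as [p, bs, s/p, h/p, d], ready to feed dist.all_to_all_single.
    assert list(x_t.shape) == [p, bs, s // p, h // p, d]
    # After a real all2all, post_all2all would reassemble [bs, s, h/p, d] per rank.
    assert post_shape == [bs, s, h // p, d]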


def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)
    num_heads = input.shape[2]
    if num_heads % seq_world_size != 0 and not scatter_idx < 2:
        raise ValueError(f'num_heads ({num_heads}) must be divisible by the sequence parallel '
                         f'world size ({seq_world_size}) when scattering over the head dimension.')
    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
        scatter_idx, batch_dim_idx, seq_world_size, input)

    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
    output = torch.empty_like(input_t)
    work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

    if async_op:
        if type in ('dq', 'dk'):
            handle[type + '_work'] = work
            handle[type + '_grad'] = output
            handle[type + '_post_all2all_func'] = post_all2all_fun
            return output.view(post_all2all_res_shape)

    res = post_all2all_fun(output)
    return res
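
# Usage sketch (illustrative, not part of this PR): in the synchronous path a caller
# with a sequence-parallel group `spg` would do something like
#   out = single_all_to_all(q_local, scatter_idx=2, gather_idx=1, batch_dim_idx=0, group=spg)
# to trade sequence shards for head shards. With async_op=True and type in ('dq', 'dk'),
# the work handle, the raw output, and the deferred post_all2all function are stashed in
# `handle` so the caller can overlap the communication and apply the post-processing later.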


class _SeqAllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any,
                group: dist.ProcessGroup,
                input: Tensor,
                scatter_idx: int,
                gather_idx: int,
                batch_dim_idx: int,
                stream=None,
                handle=None,
                type=None,
                is_fwd=True) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        ctx.stream = stream
        ctx.handle = handle
        ctx.type = type
        ctx.batch_dim_idx = batch_dim_idx
        res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
        return res

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:

        return (None,
                _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx,
                                   ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None)


class Ulysses(SequenceParallel):

    def __init__(self):
@@ -60,7 +181,7 @@ def forward(self,
                batch_dim_idx: int,
                *args: Any,
                **kwargs) -> Tensor:
        from deepspeed.sequence.layer import _SeqAllToAll
        # from deepspeed.sequence.layer import _SeqAllToAll
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx,
                                         None,
                                         self.overlap_handles, 'q')
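For context on how the pieces above fit together: the Ulysses scheme exchanges sequence shards for head shards before local attention and reverses the exchange afterwards, and `_SeqAllToAll.backward` reapplies the same op with `scatter_idx` and `gather_idx` swapped, so gradients flow through the inverse all-to-all. The sketch below is illustrative only; the group `spg`, the `local_attn` callable, and the tensor shapes are assumptions, not code from this PR.

def ulysses_attention(q, k, v, spg, local_attn, batch_dim_idx=0):
    # Assumed layout: [bs, s/p, h, d] per rank, with h divisible by the group size.
    # Scatter heads, gather sequence: [bs, s/p, h, d] -> [bs, s, h/p, d]
    q = _SeqAllToAll.apply(spg, q, 2, 1, batch_dim_idx)
    k = _SeqAllToAll.apply(spg, k, 2, 1, batch_dim_idx)
    v = _SeqAllToAll.apply(spg, v, 2, 1, batch_dim_idx)
    out = local_attn(q, k, v)  # full sequence, local subset of heads
    # Scatter sequence, gather heads: [bs, s, h/p, d] -> [bs, s/p, h, d]
    return _SeqAllToAll.apply(spg, out, 1, 2, batch_dim_idx)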
2 changes: 1 addition & 1 deletion swift/utils/logger.py
@@ -12,7 +12,7 @@
# Avoid circular reference
def _is_local_master():
    local_rank = int(os.getenv('LOCAL_RANK', -1))
    return local_rank in {-1, 0}
    return True


init_loggers = {}