[Auto Parallel] Add tensor_fusion and overlap in auto dy sharding #72551
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
```python
self.enable_tensor_fusion = (
    os.getenv("FLAGS_enable_tensor_fusion") == '1'
)
self.enable_sharding_overlap = (
```
sharding_overlap will not be set through FLAGS in the future; it would be better to configure it with Strategy.
Thx, will fix it later
LGTM
PR Category
Auto Parallel
PR Types
New features
Description
`param.main_grad` replaces the old `master_grad` in auto dy: inplace `add_` is used to save or cast grads to fp32 and store them in `param.main_grad`. Enable with `export Flags_enable_inplace_master_grad=1`.
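A minimal sketch of the accumulation idea, assuming a plain Paddle dygraph layer; the helper `accumulate_main_grad` is illustrative, not the actual sharding-optimizer code:

```python
import paddle

def accumulate_main_grad(param):
    """Fold the current grad into an fp32 `param.main_grad` using inplace `add_`."""
    if param.grad is None:
        return
    fp32_grad = param.grad.cast("float32")
    if getattr(param, "main_grad", None) is None:
        # First micro-step: the casted grad becomes the fp32 accumulator.
        param.main_grad = fp32_grad
    else:
        # Later micro-steps reuse the same buffer via inplace add_.
        param.main_grad.add_(fp32_grad)

# Tiny usage sketch (fp32 layer here; under amp the grads would be fp16/bf16).
layer = paddle.nn.Linear(4, 4)
loss = layer(paddle.randn([2, 4])).sum()
loss.backward()
for p in layer.parameters():
    accumulate_main_grad(p)
    print(p.name, p.main_grad.dtype)  # accumulators are float32
```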
`tensor_fusion` groups params and grads into contiguous `param_storage` and `grad_storage`. `grad_storage` is used for the grads' `reduce_scatter` comm and `param_storage` for the params' `all_gather` comm; individual params and grads are carved out of `param_storage` and `grad_storage` with `view_slice`. `grad_clip` requires calling `all_reduce` manually to collect `global_norm_var`. Enable with `export FLAGS_enable_tensor_fusion=1`.
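A rough, single-process sketch of the fusion idea; the helper below uses plain `concat` and slicing as a stand-in for the real `view_slice`-based storage, so the per-tensor "views" here are conceptual and may copy, unlike the actual implementation:

```python
import paddle

def fuse_into_storage(tensors):
    """Pack tensors into one contiguous 1-D buffer and record how to read them back."""
    storage = paddle.concat([t.flatten() for t in tensors])
    views, offset = [], 0
    for t in tensors:
        n = int(t.numel())
        # Conceptual per-tensor view: a reshaped slice of the fused buffer.
        views.append(storage[offset:offset + n].reshape(t.shape))
        offset += n
    return storage, views

grads = [paddle.randn([4, 4]), paddle.randn([8])]
grad_storage, grad_views = fuse_into_storage(grads)
print(grad_storage.shape)  # [24]: one buffer instead of one tensor per grad

# With fusion, one collective per bucket replaces one collective per tensor, e.g.
#   paddle.distributed.reduce_scatter(shard, paddle.split(grad_storage, world_size))
# and grad_clip has to all_reduce the local squared norm of its shard to obtain
# global_norm_var, since no rank holds the full grad after reduce_scatter.
```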
Overlap: the grads' `reduce_scatter` comm is overlapped with grad computation in bwd, and the params' `all_gather` comm is overlapped with opt computation. Enable with `export FLAGS_enable_tensor_fusion=1`.
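A hedged sketch of the overlap pattern using async collectives (`sync_op=False`); this is not the PR's scheduler, just the underlying idea, and it has to be run under `python -m paddle.distributed.launch` with multiple ranks:

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
world_size = dist.get_world_size()

shard_numel = 512
grad_storage = paddle.randn([shard_numel * world_size])  # stand-in for a fused grad bucket
shard = paddle.empty([shard_numel])

# Launch the bucket's reduce_scatter asynchronously...
task = dist.reduce_scatter(
    shard, paddle.split(grad_storage, world_size), sync_op=False
)

# ...and keep computing backward for earlier layers while the comm is in flight.
busy_work = paddle.randn([1024, 1024]) @ paddle.randn([1024, 1024])

# Block only when the sharded grad is actually needed (e.g. by the optimizer).
task.wait()

# The all_gather of updated params overlaps with optimizer compute the same way:
#   task = dist.all_gather(param_pieces, updated_shard, sync_op=False); ...; task.wait()
```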
Note: non-uniform `tensor_fusion` changes the order of `add` in `grad_clip`, introducing some loss diff. Convergence results on llama7b, 1NC8, sharding8, 50,000 steps.
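The note above comes down to floating-point addition not being associative, so re-chunking the fused buffer changes the summation order; a minimal float32 illustration:

```python
import paddle

a = paddle.to_tensor(1e8, dtype="float32")
b = paddle.to_tensor(-1e8, dtype="float32")
c = paddle.to_tensor(1.0, dtype="float32")
print(((a + b) + c).item())  # 1.0
print((a + (b + c)).item())  # 0.0 -- c is absorbed by the large value first
```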
【TODO】Add strategy config in auto-dy, like hand-dy (`fleet.init(strategy)`) and auto-static (`to_static(strategy)`).
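A hedged sketch of what that Strategy-based config might look like, modeled on the auto-static `dist.to_static(strategy)` path; the sharding field names below are assumptions, not the final API:

```python
import paddle.distributed as dist

strategy = dist.Strategy()
strategy.sharding.enable = True   # assumed field names, for illustration only
strategy.sharding.degree = 8
# strategy.sharding.enable_tensor_fusion = True  # would replace FLAGS_enable_tensor_fusion
# strategy.sharding.enable_overlap = True        # would replace the sharding_overlap FLAG
```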
Pcard-70448