
Commit bc51091

fegin authored and pytorchmergebot committed
Only make wait_tensor as a side_effect op (pytorch#132341)
Summary: pytorch#131023 added all the collective ops to the side-effect list, but we should only mark wait_tensor as a side-effect op, since every collective op should have a corresponding wait_tensor. We should switch to using the higher-order effect token.

Pull Request resolved: pytorch#132341
Approved by: https://github.com/yf225
1 parent ef426d5 commit bc51091
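As context for the change below, here is a minimal sketch (not part of this commit) of how torch.fx.node.has_side_effect interacts with dead-code elimination: FX treats nodes whose target is registered as side-effectful as impure, so graph.eliminate_dead_code() keeps them even when their result is unused. The noisy_add helper is purely illustrative.

import torch
import torch.fx


@torch.fx.wrap  # trace noisy_add as a single call_function node instead of inlining it
def noisy_add(x):
    return x + 1


# Registering the target makes its nodes "impure", so FX dead-code elimination
# will not drop calls to it even when the output has no users.
torch.fx.node.has_side_effect(noisy_add)


def fn(x):
    noisy_add(x)  # result intentionally discarded
    return x * 2


gm = torch.fx.symbolic_trace(fn)
gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.code)  # the noisy_add(x) call is still present in the generated code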

File tree

2 files changed: +22 −59 lines changed


test/distributed/test_functional_api.py

Lines changed: 22 additions & 0 deletions
@@ -587,6 +587,28 @@ def allreduce(t, pg):
         )
         allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)

+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    @requires_nccl()
+    @with_comms()
+    def test_tracing_with_dce_code(self):
+        if self.world_size > 2:
+            return
+
+        def func(batch, group, rank):
+            ret = ft_c.permute_tensor(batch, [1, 0], group)
+            if hasattr(ret, "wait"):
+                ret = ret.wait()
+            if rank == 0:
+                return ret
+            else:
+                return batch * 5
+
+        compiled_func = torch.compile(func)
+        ret = compiled_func(
+            torch.ones((100,), device="cuda"), self.process_group, self.rank
+        )
+        dist.barrier()
+

 class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
     @property

torch/distributed/_functional_collectives.py

Lines changed: 0 additions & 59 deletions
@@ -963,66 +963,7 @@ def _reduce_scatter_tensor_coalesced_native_meta(

 # mark these ops has side effect so that they won't be removed by DCE
 torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.all_gather_into_tensor_out.default
-)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.all_gather_into_tensor.default
-)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.all_gather_into_tensor_coalesced.default
-)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_reduce.default)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_reduce_.default)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.all_reduce_coalesced.default
-)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.all_reduce_coalesced_.default
-)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_to_all_single.default)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.broadcast.default)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.broadcast_.default)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.reduce_scatter_tensor.default
-)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.reduce_scatter_tensor_coalesced.default
-)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional_autograd.all_to_all_single.default
-)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional_autograd.reduce_scatter_tensor.default
-)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional_autograd.all_gather_into_tensor.default
-)
-# also the no-overload version
 torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_gather_into_tensor_out)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_gather_into_tensor)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.all_gather_into_tensor_coalesced
-)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_reduce)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_reduce_)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_reduce_coalesced)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_reduce_coalesced_)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.all_to_all_single)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.broadcast)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.broadcast_)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.reduce_scatter_tensor)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional.reduce_scatter_tensor_coalesced
-)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional_autograd.all_to_all_single)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional_autograd.reduce_scatter_tensor
-)
-torch.fx.node.has_side_effect(
-    torch.ops._c10d_functional_autograd.all_gather_into_tensor
-)

 # Register legacy ops for backward compatibility
 # TODO(yifu): remove these in functional collective beta release
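For orientation, a rough usage sketch (assumed, not part of this diff) of the pairing the summary relies on: a functional collective returns a tensor that is only materialized once it passes through wait_tensor, which is why keeping only wait_tensor side-effectful still protects the issuing collective from DCE. This assumes an initialized process group; the allreduce_sum helper is hypothetical.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c


def allreduce_sum(t: torch.Tensor) -> torch.Tensor:
    # issue the functional collective; the result is not ready yet
    out = ft_c.all_reduce(t, "sum", dist.group.WORLD)
    # wait_tensor forces completion, so it is the one node that must never be
    # dead-code-eliminated; the collective stays alive as its input
    return ft_c.wait_tensor(out)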
