
Commit c1e504e

fegin authored and pytorchmergebot committed
[SymmMEM] Move AsyncTP tests to a separate test class (pytorch#161820)
We move the AsyncTP tests to a separate test suite because 1) Async TP ops are not core symmetric memory APIs; they are more like applications, and 2) MultiProcContinuousTest skips all following tests once a test fails (we should fix this too). We still want to get test signals for the core symmetric memory APIs when Async TP ops fail.

Pull Request resolved: pytorch#161820
Approved by: https://github.com/kwen2501
1 parent 4ad9fbc · commit c1e504e
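For context, the restructuring amounts to the layout below. This is a structural sketch only, not the actual test file: a stub stands in for the real multi-process harness, `SymmetricMemoryTest` is an assumed name for the pre-existing core suite, and the AsyncTP test name is illustrative (inferred from the `gather_dim` parametrization visible in the diff further down).

# Structural sketch only; not the actual test file.
class MultiProcContinuousTest:  # stub standing in for the real harness
    pass


class SymmetricMemoryTest(MultiProcContinuousTest):  # assumed pre-existing name
    # Core symmetric-memory API tests (moved above AsyncTPTest by this
    # commit) keep reporting even when AsyncTP ops break.
    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None: ...
    def test_low_contention_reduce_scatter(self, reduce_op: str, symm_mem_input: bool) -> None: ...
    def test_subgroup(self) -> None: ...


class AsyncTPTest(MultiProcContinuousTest):
    # AsyncTP ops are applications on top of symmetric memory; since the
    # harness skips everything after a failure, isolating them here keeps
    # a failing AsyncTP test from masking the core tests above.
    def test_fused_all_gather_matmul(self, gather_dim: int) -> None: ...  # illustrative name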

1 file changed (+115 lines, -98 lines)


test/distributed/test_symmetric_memory.py

Lines changed: 115 additions & 98 deletions
@@ -152,6 +152,121 @@ def test_allow_overlapping_devices(self) -> None:
 
         os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "0"
 
+    @runOnRocmArch(MI300_ARCH)
+    @skip_if_lt_x_gpu(2)
+    @parametrize("symm_mem_input", [True, False])
+    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
+        self._init_process()
+
+        if symm_mem_input:
+            t = _SymmetricMemory.empty_strided_p2p(
+                size=(64, 64),
+                stride=(64, 1),
+                dtype=torch.float32,
+                device=self.device,
+                group_name="0",
+            ).fill_(self.rank)
+        else:
+            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)
+
+        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
+        res = torch.ops._c10d_functional.wait_tensor(res)
+        self.assertEqual(res.shape, (64 * self.world_size, 64))
+
+        chunks = res.chunk(self.world_size)
+        for r in range(self.world_size):
+            self.assertTrue(chunks[r].eq(r).all())
+
+    @runOnRocmArch(MI300_ARCH)
+    @skip_if_lt_x_gpu(2)
+    @parametrize("reduce_op", ["sum", "avg"])
+    @parametrize("symm_mem_input", [True, False])
+    def test_low_contention_reduce_scatter(
+        self, reduce_op: str, symm_mem_input: bool
+    ) -> None:
+        self._init_process()
+
+        if symm_mem_input:
+            t = _SymmetricMemory.empty_strided_p2p(
+                size=(64, 64),
+                stride=(64, 1),
+                dtype=torch.float32,
+                device=self.device,
+                group_name="0",
+            )
+        else:
+            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)
+
+        chunks = t.chunk(self.world_size)
+        for r in range(self.world_size):
+            chunks[r].fill_(r)
+
+        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
+        res = torch.ops._c10d_functional.wait_tensor(res)
+        self.assertEqual(res.shape, (64 // self.world_size, 64))
+
+        if reduce_op == "sum":
+            expect = self.rank * self.world_size
+        elif reduce_op == "avg":
+            expect = self.rank
+        else:
+            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
+        self.assertTrue(res.eq(expect).all())
+
+    @runOnRocmArch(MI300_ARCH)
+    @skip_if_lt_x_gpu(4)
+    def test_subgroup(self) -> None:
+        self._init_process()
+
+        ranks = list(range(self.world_size))
+        subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
+        subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])
+
+        world = dist.group.WORLD
+        subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
+
+        t = symm_mem.empty(64, device="cuda")
+        symm_mem_world = symm_mem.rendezvous(t, group=world)
+        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
+
+        self.assertEqual(symm_mem_world.world_size, world.size())
+        self.assertEqual(symm_mem_world.rank, world.rank())
+        self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
+        self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())
+
+        t.fill_(world.rank())
+        symm_mem_world.barrier()
+
+        # Observe a peer buffer via the world group
+        peer_rank = (world.rank() + 1) % world.size()
+        buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
+        self.assertTrue(buf.eq(peer_rank).all())
+
+        # Observe a peer buffer via the subgroup
+        peer_rank = (subgroup.rank() + 1) % subgroup.size()
+        buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
+        if world.rank() < world.size() // 2:
+            self.assertTrue(buf.eq(peer_rank).all())
+        else:
+            self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
+
+
+# We move AsyncTP tests to a separate test suite because 1) Async TP ops are not
+# the core symmetric memory APIs, they are more like applications, 2)
+# MultiProcContinuousTest will skip all the following tests if a test fails (
+# we should fix this too). We still want to get the test signals for the core
+# symmetric memory APIs when Async TP ops fail.
+@instantiate_parametrized_tests
+@requires_cuda_p2p_access()
+class AsyncTPTest(MultiProcContinuousTest):
+    @property
+    def device(self) -> torch.device:
+        return torch.device(device_type, self.rank)
+
+    def _init_process(self):
+        torch.cuda.set_device(self.device)
+        torch.manual_seed(42 + self.rank)
+
     @runOnRocmArch(MI300_ARCH)
     @skip_if_lt_x_gpu(2)
     @parametrize("gather_dim", [0, 1])
@@ -455,104 +570,6 @@ def test_optimal_layout(self, dim: int) -> None:
         self.assertTrue(x.movedim(dim, 0).is_contiguous())
         self.assertTrue(torch.allclose(x, t))
 
-    @runOnRocmArch(MI300_ARCH)
-    @skip_if_lt_x_gpu(2)
-    @parametrize("symm_mem_input", [True, False])
-    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
-        self._init_process()
-
-        if symm_mem_input:
-            t = _SymmetricMemory.empty_strided_p2p(
-                size=(64, 64),
-                stride=(64, 1),
-                dtype=torch.float32,
-                device=self.device,
-                group_name="0",
-            ).fill_(self.rank)
-        else:
-            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)
-
-        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
-        res = torch.ops._c10d_functional.wait_tensor(res)
-        self.assertEqual(res.shape, (64 * self.world_size, 64))
-
-        chunks = res.chunk(self.world_size)
-        for r in range(self.world_size):
-            self.assertTrue(chunks[r].eq(r).all())
-
-    @runOnRocmArch(MI300_ARCH)
-    @skip_if_lt_x_gpu(2)
-    @parametrize("reduce_op", ["sum", "avg"])
-    @parametrize("symm_mem_input", [True, False])
-    def test_low_contention_reduce_scatter(
-        self, reduce_op: str, symm_mem_input: bool
-    ) -> None:
-        self._init_process()
-
-        if symm_mem_input:
-            t = _SymmetricMemory.empty_strided_p2p(
-                size=(64, 64),
-                stride=(64, 1),
-                dtype=torch.float32,
-                device=self.device,
-                group_name="0",
-            )
-        else:
-            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)
-
-        chunks = t.chunk(self.world_size)
-        for r in range(self.world_size):
-            chunks[r].fill_(r)
-
-        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
-        res = torch.ops._c10d_functional.wait_tensor(res)
-        self.assertEqual(res.shape, (64 // self.world_size, 64))
-
-        if reduce_op == "sum":
-            expect = self.rank * self.world_size
-        elif reduce_op == "avg":
-            expect = self.rank
-        else:
-            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
-        self.assertTrue(res.eq(expect).all())
-
-    @runOnRocmArch(MI300_ARCH)
-    @skip_if_lt_x_gpu(4)
-    def test_subgroup(self) -> None:
-        self._init_process()
-
-        ranks = list(range(self.world_size))
-        subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
-        subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])
-
-        world = dist.group.WORLD
-        subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
-
-        t = symm_mem.empty(64, device="cuda")
-        symm_mem_world = symm_mem.rendezvous(t, group=world)
-        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
-
-        self.assertEqual(symm_mem_world.world_size, world.size())
-        self.assertEqual(symm_mem_world.rank, world.rank())
-        self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
-        self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())
-
-        t.fill_(world.rank())
-        symm_mem_world.barrier()
-
-        # Observe a peer buffer via the world group
-        peer_rank = (world.rank() + 1) % world.size()
-        buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
-        self.assertTrue(buf.eq(peer_rank).all())
-
-        # Observe a peer buffer via the subgroup
-        peer_rank = (subgroup.rank() + 1) % subgroup.size()
-        buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
-        if world.rank() < world.size() // 2:
-            self.assertTrue(buf.eq(peer_rank).all())
-        else:
-            self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
-
 
 # [READ ME FIRST]
 # The `SymmMemEmptySetDeviceTest` suite parameterizes whether user sets the
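A note on the expected values asserted in test_low_contention_reduce_scatter above: every rank fills its input so that chunk r holds the value r, and reduce-scatter hands rank i the reduction of chunk i across ranks, i.e. i * world_size for "sum" and i for "avg". The following single-process sketch checks that arithmetic with plain tensors standing in for the per-rank inputs (no process group involved):

import torch

world_size = 4  # stand-in for the test's process count

# Every rank fills its (64, 64) input identically: chunk r holds value r.
inputs = []
for _ in range(world_size):
    t = torch.empty(64, 64)
    for r, chunk in enumerate(t.chunk(world_size)):
        chunk.fill_(r)
    inputs.append(t)

# Reduce-scatter semantics: rank i receives chunk i reduced across ranks.
reduced = torch.stack(inputs).sum(dim=0)
for rank in range(world_size):
    out = reduced.chunk(world_size)[rank]
    assert out.eq(rank * world_size).all()     # reduce_op == "sum"
    assert out.div(world_size).eq(rank).all()  # reduce_op == "avg"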
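Similarly, the branch at the end of test_subgroup follows from index bookkeeping: buffers are filled with world ranks, the first half of the ranks forms subgroup_0 and the second half subgroup_1, so a peer observed through a second-half subgroup holds its subgroup-local peer rank plus world_size // 2. A minimal sketch of that mapping (pure Python; world_size = 4 as an example):

world_size = 4
half = world_size // 2

# Each rank's symmetric buffer is filled with its world rank.
buffers = list(range(world_size))

for world_rank in range(world_size):
    # Ranks [0, half) form subgroup_0; ranks [half, world_size) form subgroup_1.
    base = 0 if world_rank < half else half
    subgroup_rank = world_rank % half

    # test_subgroup reads the next peer within its subgroup:
    subgroup_peer = (subgroup_rank + 1) % half
    observed = buffers[base + subgroup_peer]

    # ...which matches the test's assertions:
    expected = subgroup_peer if world_rank < half else subgroup_peer + half
    assert observed == expected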
