@@ -152,6 +152,121 @@ def test_allow_overlapping_devices(self) -> None:
 
         os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "0"
 
+    @runOnRocmArch(MI300_ARCH)
+    @skip_if_lt_x_gpu(2)
+    @parametrize("symm_mem_input", [True, False])
+    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
+        self._init_process()
+
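+        # Build the input either directly in symmetric memory or as a regular
+        # CUDA tensor; the low-contention all-gather should accept both.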
+        if symm_mem_input:
+            t = _SymmetricMemory.empty_strided_p2p(
+                size=(64, 64),
+                stride=(64, 1),
+                dtype=torch.float32,
+                device=self.device,
+                group_name="0",
+            ).fill_(self.rank)
+        else:
+            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)
+
+        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
+        res = torch.ops._c10d_functional.wait_tensor(res)
+        self.assertEqual(res.shape, (64 * self.world_size, 64))
+
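+        # Chunk r of the gathered result comes from rank r, which filled its
+        # input with the value r.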
+        chunks = res.chunk(self.world_size)
+        for r in range(self.world_size):
+            self.assertTrue(chunks[r].eq(r).all())
+
+    @runOnRocmArch(MI300_ARCH)
+    @skip_if_lt_x_gpu(2)
+    @parametrize("reduce_op", ["sum", "avg"])
+    @parametrize("symm_mem_input", [True, False])
+    def test_low_contention_reduce_scatter(
+        self, reduce_op: str, symm_mem_input: bool
+    ) -> None:
+        self._init_process()
+
+        if symm_mem_input:
+            t = _SymmetricMemory.empty_strided_p2p(
+                size=(64, 64),
+                stride=(64, 1),
+                dtype=torch.float32,
+                device=self.device,
+                group_name="0",
+            )
+        else:
+            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)
+
+        chunks = t.chunk(self.world_size)
+        for r in range(self.world_size):
+            chunks[r].fill_(r)
+
+        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
+        res = torch.ops._c10d_functional.wait_tensor(res)
+        self.assertEqual(res.shape, (64 // self.world_size, 64))
+
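+        # On every rank, chunk r holds the value r, so the reduced chunk that
+        # lands on rank r is r * world_size for "sum" and r for "avg".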
+        if reduce_op == "sum":
+            expect = self.rank * self.world_size
+        elif reduce_op == "avg":
+            expect = self.rank
+        else:
+            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
+        self.assertTrue(res.eq(expect).all())
+
+    @runOnRocmArch(MI300_ARCH)
+    @skip_if_lt_x_gpu(4)
+    def test_subgroup(self) -> None:
+        self._init_process()
+
+        ranks = list(range(self.world_size))
+        subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
+        subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])
+
+        world = dist.group.WORLD
+        subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
+
+        t = symm_mem.empty(64, device="cuda")
+        symm_mem_world = symm_mem.rendezvous(t, group=world)
+        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
+
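+        # The same allocation is rendezvoused with both the world group and a
+        # subgroup; each handle reports rank/world_size relative to its own group.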
+        self.assertEqual(symm_mem_world.world_size, world.size())
+        self.assertEqual(symm_mem_world.rank, world.rank())
+        self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
+        self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())
+
+        t.fill_(world.rank())
+        symm_mem_world.barrier()
+
+        # Observe a peer buffer via the world group
+        peer_rank = (world.rank() + 1) % world.size()
+        buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
+        self.assertTrue(buf.eq(peer_rank).all())
+
+        # Observe a peer buffer via the subgroup
+        peer_rank = (subgroup.rank() + 1) % subgroup.size()
+        buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
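+        # Subgroup rank r in the second half maps to global rank
+        # r + world_size // 2, and each buffer was filled with the global rank.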
+        if world.rank() < world.size() // 2:
+            self.assertTrue(buf.eq(peer_rank).all())
+        else:
+            self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
+
+
+# We move the AsyncTP tests into a separate test suite because 1) the Async TP
+# ops are not core symmetric memory APIs (they are more like applications), and
+# 2) MultiProcContinuousTest skips all subsequent tests once one test fails
+# (which we should also fix). We still want test signals for the core symmetric
+# memory APIs even when the Async TP ops fail.
+@instantiate_parametrized_tests
+@requires_cuda_p2p_access()
+class AsyncTPTest(MultiProcContinuousTest):
+    @property
+    def device(self) -> torch.device:
+        return torch.device(device_type, self.rank)
+
+    def _init_process(self):
+        torch.cuda.set_device(self.device)
+        torch.manual_seed(42 + self.rank)
+
     @runOnRocmArch(MI300_ARCH)
     @skip_if_lt_x_gpu(2)
     @parametrize("gather_dim", [0, 1])
@@ -455,104 +570,6 @@ def test_optimal_layout(self, dim: int) -> None:
         self.assertTrue(x.movedim(dim, 0).is_contiguous())
         self.assertTrue(torch.allclose(x, t))
 
-    @runOnRocmArch(MI300_ARCH)
-    @skip_if_lt_x_gpu(2)
-    @parametrize("symm_mem_input", [True, False])
-    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
-        self._init_process()
-
-        if symm_mem_input:
-            t = _SymmetricMemory.empty_strided_p2p(
-                size=(64, 64),
-                stride=(64, 1),
-                dtype=torch.float32,
-                device=self.device,
-                group_name="0",
-            ).fill_(self.rank)
-        else:
-            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)
-
-        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
-        res = torch.ops._c10d_functional.wait_tensor(res)
-        self.assertEqual(res.shape, (64 * self.world_size, 64))
-
-        chunks = res.chunk(self.world_size)
-        for r in range(self.world_size):
-            self.assertTrue(chunks[r].eq(r).all())
-
-    @runOnRocmArch(MI300_ARCH)
-    @skip_if_lt_x_gpu(2)
-    @parametrize("reduce_op", ["sum", "avg"])
-    @parametrize("symm_mem_input", [True, False])
-    def test_low_contention_reduce_scatter(
-        self, reduce_op: str, symm_mem_input: bool
-    ) -> None:
-        self._init_process()
-
-        if symm_mem_input:
-            t = _SymmetricMemory.empty_strided_p2p(
-                size=(64, 64),
-                stride=(64, 1),
-                dtype=torch.float32,
-                device=self.device,
-                group_name="0",
-            )
-        else:
-            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)
-
-        chunks = t.chunk(self.world_size)
-        for r in range(self.world_size):
-            chunks[r].fill_(r)
-
-        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
-        res = torch.ops._c10d_functional.wait_tensor(res)
-        self.assertEqual(res.shape, (64 // self.world_size, 64))
-
-        if reduce_op == "sum":
-            expect = self.rank * self.world_size
-        elif reduce_op == "avg":
-            expect = self.rank
-        else:
-            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
-        self.assertTrue(res.eq(expect).all())
-
-    @runOnRocmArch(MI300_ARCH)
-    @skip_if_lt_x_gpu(4)
-    def test_subgroup(self) -> None:
-        self._init_process()
-
-        ranks = list(range(self.world_size))
-        subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
-        subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])
-
-        world = dist.group.WORLD
-        subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
-
-        t = symm_mem.empty(64, device="cuda")
-        symm_mem_world = symm_mem.rendezvous(t, group=world)
-        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
-
-        self.assertEqual(symm_mem_world.world_size, world.size())
-        self.assertEqual(symm_mem_world.rank, world.rank())
-        self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
-        self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())
-
-        t.fill_(world.rank())
-        symm_mem_world.barrier()
-
-        # Observe a peer buffer via the world group
-        peer_rank = (world.rank() + 1) % world.size()
-        buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
-        self.assertTrue(buf.eq(peer_rank).all())
-
-        # Observe a peer buffer via the subgroup
-        peer_rank = (subgroup.rank() + 1) % subgroup.size()
-        buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
-        if world.rank() < world.size() // 2:
-            self.assertTrue(buf.eq(peer_rank).all())
-        else:
-            self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
-
 
 # [READ ME FIRST]
 # The `SymmMemEmptySetDeviceTest` suite parameterizes whether user sets the