@@ -87,34 +87,40 @@ def run_test_distributed(self, fn, device_option=None, **kwargs):
8787 fn (** kwargs )
8888 workspace .ResetWorkspace ()
8989
90- def create_common_world (self , comm_rank , comm_size , tmpdir = None ):
90+ def create_common_world (self , comm_rank , comm_size , tmpdir = None , existing_cw = None ):
9191 store_handler = "store_handler"
9292
9393 # If REDIS_HOST is set, use RedisStoreHandler for rendezvous.
94- redis_host = os .getenv ("REDIS_HOST" )
95- redis_port = int (os .getenv ("REDIS_PORT" , 6379 ))
96- if redis_host is not None :
97- workspace .RunOperatorOnce (
98- core .CreateOperator (
99- "RedisStoreHandlerCreate" ,
100- [],
101- [store_handler ],
102- prefix = str (TestCase .test_counter ) + "/" ,
103- host = redis_host ,
104- port = redis_port ))
94+ if existing_cw is None :
95+ redis_host = os .getenv ("REDIS_HOST" )
96+ redis_port = int (os .getenv ("REDIS_PORT" , 6379 ))
97+ if redis_host is not None :
98+ workspace .RunOperatorOnce (
99+ core .CreateOperator (
100+ "RedisStoreHandlerCreate" ,
101+ [],
102+ [store_handler ],
103+ prefix = str (TestCase .test_counter ) + "/" ,
104+ host = redis_host ,
105+ port = redis_port ))
106+ else :
107+ workspace .RunOperatorOnce (
108+ core .CreateOperator (
109+ "FileStoreHandlerCreate" ,
110+ [],
111+ [store_handler ],
112+ path = tmpdir ))
113+ common_world = "common_world"
105114 else :
106- workspace .RunOperatorOnce (
107- core .CreateOperator (
108- "FileStoreHandlerCreate" ,
109- [],
110- [store_handler ],
111- path = tmpdir ))
115+ common_world = str (existing_cw ) + ".forked"
112116
113- common_world = "common_world"
117+ inputs = [store_handler ]
118+ if existing_cw is not None :
119+ inputs .append (existing_cw )
114120 workspace .RunOperatorOnce (
115121 core .CreateOperator (
116122 "CreateCommonWorld" ,
117- [ store_handler ] ,
123+ inputs ,
118124 [common_world ],
119125 size = comm_size ,
120126 rank = comm_rank ,
@@ -269,6 +275,45 @@ def _test_allreduce(self,
269275 for _tmp in range (4 ):
270276 workspace .RunNet (net .Name ())
271277
def _test_allreduce_multicw(self,
                            comm_rank=None,
                            comm_size=None,
                            tmpdir=None
                            ):
    """Run an Allreduce on a common world and on a second one derived from it.

    Two common worlds are created through ``self.create_common_world``: the
    first from scratch, the second by passing the first as ``existing_cw``.
    The same set of blobs is then fed and all-reduced once per common world,
    and every resulting blob is compared against the expected total.

    Args:
        comm_rank: rank of this participant; also used to derive the fill
            value of each blob.
        comm_size: number of participants in the common world.
        tmpdir: directory forwarded to ``create_common_world``.
    """
    _store_handler, common_world = self.create_common_world(
        comm_rank=comm_rank,
        comm_size=comm_size,
        tmpdir=tmpdir)

    _, common_world2 = self.create_common_world(
        comm_rank=comm_rank,
        comm_size=comm_size,
        tmpdir=tmpdir,
        existing_cw=common_world)

    # Must be an int: NumPy rejects a float such as 1e4 as an array shape.
    blob_size = 10000
    num_blobs = 4

    # Fill values are comm_rank * num_blobs + i, so the expected result is
    # the sum 0 + 1 + ... + (num_blobs * comm_size - 1).
    # NOTE(review): assumes ranks cover 0..comm_size-1 -- confirm with callers.
    total_values = num_blobs * comm_size
    expected = total_values * (total_values - 1) / 2

    for cw in [common_world, common_world2]:
        # Re-feed the inputs on every pass: the Allreduce below writes its
        # outputs into the same blobs it reads from.
        blobs = []
        for i in range(num_blobs):
            blob = "blob_{}".format(i)
            value = np.full(
                blob_size, (comm_rank * num_blobs) + i, np.float32)
            workspace.FeedBlob(blob, value)
            blobs.append(blob)

        net = core.Net("allreduce_multicw")
        net.Allreduce(
            [cw] + blobs,
            blobs,
            engine=op_engine)

        workspace.RunNetOnce(net)
        for blob in blobs:
            np.testing.assert_array_equal(
                workspace.FetchBlob(blob),
                expected)
316+
272317 @given (comm_size = st .integers (min_value = 2 , max_value = 8 ),
273318 blob_size = st .integers (min_value = 1e3 , max_value = 1e6 ),
274319 num_blobs = st .integers (min_value = 1 , max_value = 4 ),
@@ -295,6 +340,21 @@ def test_allreduce(self, comm_size, blob_size, num_blobs, device_option,
295340 tmpdir = tmpdir ,
296341 use_float16 = use_float16 )
297342
@given(device_option=st.sampled_from([hu.cpu_do]))
def test_forked_cw(self, device_option):
    """Exercise ``_test_allreduce_multicw`` either distributed or locally.

    When the ``COMM_RANK`` environment variable is set, the test body is
    handed to ``run_test_distributed``; otherwise it is run through
    ``run_test_locally`` with 8 participants and a temporary directory.
    """
    TestCase.test_counter += 1
    multicw_test = self._test_allreduce_multicw

    if os.getenv('COMM_RANK') is None:
        # No rank in the environment: run locally with a scratch directory.
        with TemporaryDirectory() as tmpdir:
            self.run_test_locally(
                multicw_test,
                comm_size=8,
                device_option=device_option,
                tmpdir=tmpdir)
        return

    self.run_test_distributed(
        multicw_test,
        device_option=device_option)
357+
298358 def _test_barrier (
299359 self ,
300360 comm_rank = None ,
0 commit comments