
Commit bda8b4e

Aapo Kyrola authored and facebook-github-bot committed
enable CreateCommonWorld to bootstrap from existing common world
Summary: Use romain-intel's ContextFactory to create common worlds from existing common worlds, thus bypassing the KV store completely. Changed data_parallel_model to automatically detect whether there is already a CW we can work from. CreateCommonWorldOp takes an optional second input, which is the existing CW.

Reviewed By: andrewwdye

Differential Revision: D5494956

fbshipit-source-id: 5f7a840bcd5fe4ea756fafeacc746bc2cf5078b0
1 parent 7811796 commit bda8b4e
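
For orientation before the diffs: the new contract can be exercised straight from the Caffe2 Python API. The sketch below is illustrative only (blob names, the store path, and the engine string are assumptions, and in a real run every rank executes the same ops with its own rank value); it mirrors the pattern the updated gloo_test.py uses: rendezvous once through a store handler, then pass the resulting common world as the optional second input so the next CreateCommonWorld forks it instead of going back to the KV store.

    from caffe2.python import core, workspace

    # Rendezvous the first common world through a file-based KV store (path is hypothetical).
    workspace.RunOperatorOnce(core.CreateOperator(
        "FileStoreHandlerCreate", [], ["store_handler"], path="/tmp/rendezvous"))
    workspace.RunOperatorOnce(core.CreateOperator(
        "CreateCommonWorld", ["store_handler"], ["common_world"],
        size=2, rank=0, engine="GLOO"))

    # Fork a second common world from the first; the optional second input lets
    # the op reuse the established connections instead of hitting the KV store again.
    workspace.RunOperatorOnce(core.CreateOperator(
        "CreateCommonWorld", ["store_handler", "common_world"], ["common_world_forked"],
        size=2, rank=0, engine="GLOO"))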

File tree: 4 files changed, +119 -27 lines


caffe2/contrib/gloo/common_world_ops.h

Lines changed: 19 additions & 4 deletions
@@ -52,9 +52,24 @@ class CreateCommonWorld final : public Operator<Context> {
 
     try {
       // Create context and connect everyone to everyone
-      auto context =
-          std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
-      context->connectFullMesh(store, device_);
+      std::shared_ptr<::gloo::Context> context;
+
+      if (InputSize() == 1) {
+        auto new_context =
+            std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
+        new_context->connectFullMesh(store, device_);
+        context = std::move(new_context);
+      } else {
+        VLOG(1) << "Creating new common world by forking existing one.";
+        auto backingCommonWorld =
+            OperatorBase::Input<std::shared_ptr<::gloo::Context>>(EXISTING_CW);
+
+        // Confirm the backing common world is compatible with this op
+        CAFFE_ENFORCE_EQ(rank_, backingCommonWorld->rank);
+        CAFFE_ENFORCE_EQ(size_, backingCommonWorld->size);
+        ::gloo::rendezvous::ContextFactory factory(backingCommonWorld);
+        context = factory.makeContext(device_);
+      }
 
       // Switch pairs to synchronous mode if configured to do so
       if (sync_) {
@@ -100,7 +115,7 @@ class CreateCommonWorld final : public Operator<Context> {
   Workspace* ws_;
   std::string status_blob_;
 
-  INPUT_TAGS(STORE_HANDLER);
+  INPUT_TAGS(STORE_HANDLER, EXISTING_CW);
   OUTPUT_TAGS(COMM);
 };

caffe2/contrib/gloo/gloo_test.py

Lines changed: 80 additions & 20 deletions
@@ -87,34 +87,40 @@ def run_test_distributed(self, fn, device_option=None, **kwargs):
         fn(**kwargs)
         workspace.ResetWorkspace()
 
-    def create_common_world(self, comm_rank, comm_size, tmpdir=None):
+    def create_common_world(self, comm_rank, comm_size, tmpdir=None, existing_cw=None):
         store_handler = "store_handler"
 
         # If REDIS_HOST is set, use RedisStoreHandler for rendezvous.
-        redis_host = os.getenv("REDIS_HOST")
-        redis_port = int(os.getenv("REDIS_PORT", 6379))
-        if redis_host is not None:
-            workspace.RunOperatorOnce(
-                core.CreateOperator(
-                    "RedisStoreHandlerCreate",
-                    [],
-                    [store_handler],
-                    prefix=str(TestCase.test_counter) + "/",
-                    host=redis_host,
-                    port=redis_port))
+        if existing_cw is None:
+            redis_host = os.getenv("REDIS_HOST")
+            redis_port = int(os.getenv("REDIS_PORT", 6379))
+            if redis_host is not None:
+                workspace.RunOperatorOnce(
+                    core.CreateOperator(
+                        "RedisStoreHandlerCreate",
+                        [],
+                        [store_handler],
+                        prefix=str(TestCase.test_counter) + "/",
+                        host=redis_host,
+                        port=redis_port))
+            else:
+                workspace.RunOperatorOnce(
+                    core.CreateOperator(
+                        "FileStoreHandlerCreate",
+                        [],
+                        [store_handler],
+                        path=tmpdir))
+            common_world = "common_world"
         else:
-            workspace.RunOperatorOnce(
-                core.CreateOperator(
-                    "FileStoreHandlerCreate",
-                    [],
-                    [store_handler],
-                    path=tmpdir))
+            common_world = str(existing_cw) + ".forked"
 
-        common_world = "common_world"
+        inputs = [store_handler]
+        if existing_cw is not None:
+            inputs.append(existing_cw)
         workspace.RunOperatorOnce(
             core.CreateOperator(
                 "CreateCommonWorld",
-                [store_handler],
+                inputs,
                 [common_world],
                 size=comm_size,
                 rank=comm_rank,
@@ -269,6 +275,45 @@ def _test_allreduce(self,
         for _tmp in range(4):
             workspace.RunNet(net.Name())
 
+    def _test_allreduce_multicw(self,
+                                comm_rank=None,
+                                comm_size=None,
+                                tmpdir=None
+                                ):
+        _store_handler, common_world = self.create_common_world(
+            comm_rank=comm_rank,
+            comm_size=comm_size,
+            tmpdir=tmpdir)
+
+        _, common_world2 = self.create_common_world(
+            comm_rank=comm_rank,
+            comm_size=comm_size,
+            tmpdir=tmpdir,
+            existing_cw=common_world)
+
+        blob_size = 1e4
+        num_blobs = 4
+
+        for cw in [common_world, common_world2]:
+            blobs = []
+            for i in range(num_blobs):
+                blob = "blob_{}".format(i)
+                value = np.full(blob_size, (comm_rank * num_blobs) + i, np.float32)
+                workspace.FeedBlob(blob, value)
+                blobs.append(blob)
+
+            net = core.Net("allreduce_multicw")
+            net.Allreduce(
+                [cw] + blobs,
+                blobs,
+                engine=op_engine)
+
+            workspace.RunNetOnce(net)
+            for i in range(num_blobs):
+                np.testing.assert_array_equal(
+                    workspace.FetchBlob(blobs[i]),
+                    (num_blobs * comm_size) * (num_blobs * comm_size - 1) / 2)
+
     @given(comm_size=st.integers(min_value=2, max_value=8),
            blob_size=st.integers(min_value=1e3, max_value=1e6),
            num_blobs=st.integers(min_value=1, max_value=4),
@@ -295,6 +340,21 @@ def test_allreduce(self, comm_size, blob_size, num_blobs, device_option,
                 tmpdir=tmpdir,
                 use_float16=use_float16)
 
+    @given(device_option=st.sampled_from([hu.cpu_do]))
+    def test_forked_cw(self, device_option):
+        TestCase.test_counter += 1
+        if os.getenv('COMM_RANK') is not None:
+            self.run_test_distributed(
+                self._test_allreduce_multicw,
+                device_option=device_option)
+        else:
+            with TemporaryDirectory() as tmpdir:
+                self.run_test_locally(
+                    self._test_allreduce_multicw,
+                    comm_size=8,
+                    device_option=device_option,
+                    tmpdir=tmpdir)
+
     def _test_barrier(
         self,
         comm_rank=None,

caffe2/operators/communicator_op.cc

Lines changed: 5 additions & 1 deletion
@@ -4,12 +4,16 @@
 namespace caffe2 {
 
 OPERATOR_SCHEMA(CreateCommonWorld)
-    .NumInputs(0, 1)
+    .NumInputs(0, 2)
     .NumOutputs(1)
     .SetDoc(R"DOC(
 Creates a common world for communication operators.
 )DOC")
     .Input(0, "kv_handler", "Key/value handler for rendezvous (optional).")
+    .Input(
+        1,
+        "existing_common_world",
+        "existing c-w that can be used to fork new one faster (optional).")
     .Output(0, "comm_world", "A common world for collective operations.")
     .Arg("size", "(int) size of the common world.")
     .Arg("rank", "(int) rank of this node in the common world.");

caffe2/python/data_parallel_model.py

Lines changed: 15 additions & 2 deletions
@@ -544,7 +544,8 @@ def Synchronize(model, timeout_sec=30):
     barrier_instance += 1
     barrier_net = core.Net("sync_barrier_net_" + str(instance))
     comm_world = barrier_net.CreateCommonWorld(
-        model._rendezvous['kv_handler'],
+        [model._rendezvous['kv_handler']] +
+        _GetCommonWorldToFork(model.param_init_net),
         "sync_barrier_cw_" + str(instance),
         name="sync_barrier_cw_op_" + str(instance),
         size=model._rendezvous['num_shards'],
@@ -931,7 +932,8 @@ def get_control_and_context(self, control_output_blob):
         current_slot = self.counter % self.max_concurrent_context
         if len(self.common_worlds) < self.max_concurrent_context:
             common_world = self.param_init_net.CreateCommonWorld(
-                self.rendezvous['kv_handler'],
+                [self.rendezvous['kv_handler']] +
+                _GetCommonWorldToFork(self.param_init_net),
                 "{}_{}_cw".format(self.name, current_slot),
                 name="{}_{}_cw_op".format(self.name, current_slot),
                 size=self.rendezvous['num_shards'],
@@ -1366,6 +1368,17 @@ def OptimizeGradientMemory(model,
     )
 
 
+def _GetCommonWorldToFork(param_init_net):
+    '''
+    We can fork common worlds from existing ones. So inspect the param_init_net
+    for an already created commonworld
+    '''
+    for op in param_init_net.Proto().op:
+        if op.type == "CreateCommonWorld":
+            return [op.output[0]]
+    return []
+
+
 barrier_instance = 0
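
To make the helper's effect concrete, here is a minimal sketch (assumed net name, world size, and engine string; not code from this commit) of how _GetCommonWorldToFork composes with the CreateCommonWorld calls patched above: once param_init_net already contains one CreateCommonWorld op, later calls receive its output as a second input and fork it rather than rendezvousing again.

    from caffe2.python import core

    param_init_net = core.Net("param_init_net")
    kv_handler = "store_handler"  # assumed to be created elsewhere in the workspace

    # First common world: only the KV handler is available, so the lookup finds nothing.
    first_cw = param_init_net.CreateCommonWorld(
        [kv_handler], "first_cw", size=4, rank=0, engine="GLOO")

    # Later common worlds: the same scan _GetCommonWorldToFork performs finds
    # "first_cw" in the net proto, so the op gets [kv_handler, first_cw] and is
    # forked from the existing world instead of re-running rendezvous.
    existing = [op.output[0] for op in param_init_net.Proto().op
                if op.type == "CreateCommonWorld"][:1]
    second_cw = param_init_net.CreateCommonWorld(
        [kv_handler] + existing, "second_cw", size=4, rank=0, engine="GLOO")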