Skip to content

Commit 06d564b

Browse files
jysohn23 and davidel authored
Cherrypick fix for DLRM real dataset crash (pytorch#2409)
Co-authored-by: Davide Libenzi <[email protected]>
1 parent ade6927 commit 06d564b

File tree

5 files changed

+65
-37
lines changed

5 files changed

+65
-37
lines changed

third_party/xla_client/multi_wait.cc

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ void MultiWait::Done() {
1111
{
1212
std::lock_guard<std::mutex> lock(mutex_);
1313
completed_count_ += 1;
14-
notify = completed_count_ >= count_;
14+
notify = completed_count_ == count_;
1515
}
1616
if (notify) {
1717
cv_.notify_all();
@@ -45,17 +45,27 @@ void MultiWait::Reset(size_t count) {
4545
}
4646

4747
std::function<void()> MultiWait::Completer(std::function<void()> func) {
48-
auto completer = [this, func = std::move(func)]() {
49-
try {
50-
func();
51-
} catch (...) {
52-
std::lock_guard<std::mutex> lock(mutex_);
53-
exptr_ = std::current_exception();
54-
}
55-
Done();
48+
auto completer = [this, func = std::move(func)]() { Complete(func); };
49+
return completer;
50+
}
51+
52+
std::function<void()> MultiWait::Completer(std::shared_ptr<MultiWait> mwait,
53+
std::function<void()> func) {
54+
auto completer = [mwait = std::move(mwait), func = std::move(func)]() {
55+
mwait->Complete(func);
5656
};
5757
return completer;
5858
}
5959

60+
void MultiWait::Complete(const std::function<void()>& func) {
61+
try {
62+
func();
63+
} catch (...) {
64+
std::lock_guard<std::mutex> lock(mutex_);
65+
exptr_ = std::current_exception();
66+
}
67+
Done();
68+
}
69+
6070
} // namespace util
6171
} // namespace xla

third_party/xla_client/multi_wait.h

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
#include <condition_variable>
55
#include <functional>
6+
#include <memory>
67
#include <mutex>
78

89
#include "tensorflow/compiler/xla/types.h"
@@ -31,10 +32,19 @@ class MultiWait {
3132

3233
// Creates a completer functor which signals the multi wait object once func
3334
// has completed. Handles exceptions by signaling the multi wait with the
34-
// proper status value.
35+
// proper status value. This API returns a function which captures a MultiWait
36+
// reference, so care must be taken such that the reference remains valid for
37+
// the whole lifetime of the returned function.
3538
std::function<void()> Completer(std::function<void()> func);
3639

40+
// Similar to the above API, but with explicit capture of the MultiWait shared
41+
// pointer.
42+
static std::function<void()> Completer(std::shared_ptr<MultiWait> mwait,
43+
std::function<void()> func);
44+
3745
private:
46+
void Complete(const std::function<void()>& func);
47+
3848
std::mutex mutex_;
3949
std::condition_variable cv_;
4050
size_t count_ = 0;

third_party/xla_client/xrt_computation_client.cc

Lines changed: 27 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -302,7 +302,7 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
302302
}
303303
XLA_COUNTER("XrtPartitionedTransferToServer", 1);
304304

305-
util::MultiWait mwait(partitions.size());
305+
auto mwait = std::make_shared<util::MultiWait>(partitions.size());
306306
std::vector<DataPtr> results(tensors.size());
307307
for (size_t i = 0; i < partitions.size(); ++i) {
308308
auto sender = [&, i]() {
@@ -316,9 +316,10 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
316316
results[base_index + r] = std::move(partitions_results[r]);
317317
}
318318
};
319-
env::ScheduleIoClosure(mwait.Completer(std::move(sender)));
319+
env::ScheduleIoClosure(
320+
util::MultiWait::Completer(mwait, std::move(sender)));
320321
}
321-
mwait.Wait();
322+
mwait->Wait();
322323
return results;
323324
}
324325

@@ -330,7 +331,7 @@ XrtComputationClient::TransferToServerInternal(
330331
std::mutex lock;
331332
XrtSessionCache::SessionMap session_map;
332333
int64 total_size = 0;
333-
util::MultiWait mwait(tensors.size());
334+
auto mwait = std::make_shared<util::MultiWait>(tensors.size());
334335
std::map<XrtSession*, SessionWork> session_work_map;
335336
{
336337
metrics::TimedSection timed(TransferToServerTransformMetric());
@@ -363,13 +364,14 @@ XrtComputationClient::TransferToServerInternal(
363364
total_size += tdata.size();
364365
}
365366
};
366-
env::ScheduleClosure(mwait.Completer(std::move(converter)));
367+
env::ScheduleClosure(
368+
util::MultiWait::Completer(mwait, std::move(converter)));
367369
}
368-
mwait.Wait();
370+
mwait->Wait();
369371
}
370372
OutboundDataMetric()->AddSample(total_size);
371373

372-
mwait.Reset(session_work_map.size());
374+
mwait->Reset(session_work_map.size());
373375
std::vector<DataPtr> results(tensors.size());
374376
for (auto& session_session_work : session_work_map) {
375377
XrtSession* session = session_session_work.first;
@@ -388,9 +390,10 @@ XrtComputationClient::TransferToServerInternal(
388390
}
389391
CreateDataHandlesCounter()->AddValue(outputs.size());
390392
};
391-
env::ScheduleIoClosure(mwait.Completer(std::move(runner)));
393+
env::ScheduleIoClosure(
394+
util::MultiWait::Completer(mwait, std::move(runner)));
392395
}
393-
mwait.Wait();
396+
mwait->Wait();
394397
return results;
395398
}
396399

@@ -426,7 +429,7 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
426429
session_work->index_mapping.push_back(i);
427430
}
428431

429-
util::MultiWait mwait(session_work_map.size());
432+
auto mwait = std::make_shared<util::MultiWait>(session_work_map.size());
430433
std::atomic<int64> total_size(0);
431434
std::vector<Literal> results(handles.size());
432435
for (auto& session_session_work : session_work_map) {
@@ -446,9 +449,10 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
446449
total_size += results[li].size_bytes();
447450
}
448451
};
449-
env::ScheduleIoClosure(mwait.Completer(std::move(runner)));
452+
env::ScheduleIoClosure(
453+
util::MultiWait::Completer(mwait, std::move(runner)));
450454
}
451-
mwait.Wait();
455+
mwait->Wait();
452456
InboundDataMetric()->AddSample(total_size.load());
453457
return results;
454458
}
@@ -458,7 +462,7 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
458462
metrics::TimedSection timed(CompileMetric());
459463

460464
std::mutex lock;
461-
util::MultiWait mwait(instances.size());
465+
auto mwait = std::make_shared<util::MultiWait>(instances.size());
462466
std::vector<ProgramShape> program_shapes(instances.size());
463467
std::vector<ComputationPtr> results(instances.size());
464468
std::vector<CompilationCacheKey> cache_keys(instances.size());
@@ -499,10 +503,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
499503
results[i] = computation_ptr;
500504
}
501505
};
502-
env::ScheduleClosure(mwait.Completer(std::move(builder)));
506+
env::ScheduleClosure(util::MultiWait::Completer(mwait, std::move(builder)));
503507
}
504-
mwait.Wait();
505-
mwait.Reset(session_work_map.size());
508+
mwait->Wait();
509+
mwait->Reset(session_work_map.size());
506510

507511
for (auto& session_and_work : session_work_map) {
508512
XrtSession* session = session_and_work.first;
@@ -532,9 +536,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
532536
CreateCompileHandlesCounter()->AddValue(1);
533537
}
534538
};
535-
env::ScheduleIoClosure(mwait.Completer(std::move(session_runner)));
539+
env::ScheduleIoClosure(
540+
util::MultiWait::Completer(mwait, std::move(session_runner)));
536541
}
537-
mwait.Wait();
542+
mwait->Wait();
538543
return results;
539544
}
540545

@@ -626,7 +631,7 @@ XrtComputationClient::RunComputations(
626631
}
627632
XLA_CHECK_EQ(computations.size(), devices.size());
628633

629-
util::MultiWait mwait(session_replicas.size());
634+
auto mwait = std::make_shared<util::MultiWait>(session_replicas.size());
630635
std::vector<std::vector<DataPtr>> results(devices.size());
631636
for (auto& sess_replica : session_replicas) {
632637
XrtSession* session = sess_replica.first;
@@ -655,9 +660,10 @@ XrtComputationClient::RunComputations(
655660
GetEffectiveDevice(devices[replica]));
656661
}
657662
};
658-
env::ScheduleIoClosure(mwait.Completer(std::move(session_runner)));
663+
env::ScheduleIoClosure(
664+
util::MultiWait::Completer(mwait, std::move(session_runner)));
659665
}
660-
mwait.Wait();
666+
mwait->Wait();
661667
return results;
662668
}
663669

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -480,7 +480,7 @@ py::bytes ReadTfFile(tensorflow::RandomAccessFile* file, uint64_t offset,
480480
std::min<size_t>(num_threads, std::thread::hardware_concurrency());
481481
size_t block_size = size / num_threads;
482482

483-
xla::util::MultiWait mwait(num_threads);
483+
auto mwait = std::make_shared<xla::util::MultiWait>(num_threads);
484484
for (size_t i = 0; i < num_threads; ++i) {
485485
auto reader = [&, i]() {
486486
uint64_t base = static_cast<uint64_t>(i) * block_size;
@@ -491,9 +491,10 @@ py::bytes ReadTfFile(tensorflow::RandomAccessFile* file, uint64_t offset,
491491
XLA_CHECK_OK(
492492
file->Read(offset + base, tsize, &result, buffer.get() + base));
493493
};
494-
xla::env::ScheduleIoClosure(mwait.Completer(std::move(reader)));
494+
xla::env::ScheduleIoClosure(
495+
xla::util::MultiWait::Completer(mwait, std::move(reader)));
495496
}
496-
mwait.Wait();
497+
mwait->Wait();
497498
}
498499
return py::bytes(buffer.get(), size);
499500
}

torch_xla/csrc/tensor_util.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -423,15 +423,16 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape,
423423
std::vector<xla::int64> iter_dims = GetIterationDimensions(dest_shape);
424424
std::vector<CopyPartition> parts =
425425
CreateCopyPartitions(dest_shape.dimensions(), iter_dims.front());
426-
xla::util::MultiWait mwait(parts.size());
426+
auto mwait = std::make_shared<xla::util::MultiWait>(parts.size());
427427
for (size_t i = 0; i < parts.size(); ++i) {
428428
auto copy_fn = [&, i]() {
429429
SlicedCopy<SType, DType>(dest_shape.dimensions(), src_data, src_strides,
430430
dest_data, dest_strides, iter_dims, parts[i]);
431431
};
432-
xla::env::ScheduleClosure(mwait.Completer(std::move(copy_fn)));
432+
xla::env::ScheduleClosure(
433+
xla::util::MultiWait::Completer(mwait, std::move(copy_fn)));
433434
}
434-
mwait.Wait();
435+
mwait->Wait();
435436
}
436437
}
437438

0 commit comments

Comments
 (0)