@@ -302,7 +302,7 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
302
302
}
303
303
XLA_COUNTER (" XrtPartitionedTransferToServer" , 1 );
304
304
305
- util::MultiWait mwait (partitions.size ());
305
+ auto mwait = std::make_shared< util::MultiWait> (partitions.size ());
306
306
std::vector<DataPtr> results (tensors.size ());
307
307
for (size_t i = 0 ; i < partitions.size (); ++i) {
308
308
auto sender = [&, i]() {
@@ -316,9 +316,10 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
316
316
results[base_index + r] = std::move (partitions_results[r]);
317
317
}
318
318
};
319
- env::ScheduleIoClosure (mwait.Completer (std::move (sender)));
319
+ env::ScheduleIoClosure (
320
+ util::MultiWait::Completer (mwait, std::move (sender)));
320
321
}
321
- mwait. Wait ();
322
+ mwait-> Wait ();
322
323
return results;
323
324
}
324
325
@@ -330,7 +331,7 @@ XrtComputationClient::TransferToServerInternal(
330
331
std::mutex lock;
331
332
XrtSessionCache::SessionMap session_map;
332
333
int64 total_size = 0 ;
333
- util::MultiWait mwait (tensors.size ());
334
+ auto mwait = std::make_shared< util::MultiWait> (tensors.size ());
334
335
std::map<XrtSession*, SessionWork> session_work_map;
335
336
{
336
337
metrics::TimedSection timed (TransferToServerTransformMetric ());
@@ -363,13 +364,14 @@ XrtComputationClient::TransferToServerInternal(
363
364
total_size += tdata.size ();
364
365
}
365
366
};
366
- env::ScheduleClosure (mwait.Completer (std::move (converter)));
367
+ env::ScheduleClosure (
368
+ util::MultiWait::Completer (mwait, std::move (converter)));
367
369
}
368
- mwait. Wait ();
370
+ mwait-> Wait ();
369
371
}
370
372
OutboundDataMetric ()->AddSample (total_size);
371
373
372
- mwait. Reset (session_work_map.size ());
374
+ mwait-> Reset (session_work_map.size ());
373
375
std::vector<DataPtr> results (tensors.size ());
374
376
for (auto & session_session_work : session_work_map) {
375
377
XrtSession* session = session_session_work.first ;
@@ -388,9 +390,10 @@ XrtComputationClient::TransferToServerInternal(
388
390
}
389
391
CreateDataHandlesCounter ()->AddValue (outputs.size ());
390
392
};
391
- env::ScheduleIoClosure (mwait.Completer (std::move (runner)));
393
+ env::ScheduleIoClosure (
394
+ util::MultiWait::Completer (mwait, std::move (runner)));
392
395
}
393
- mwait. Wait ();
396
+ mwait-> Wait ();
394
397
return results;
395
398
}
396
399
@@ -426,7 +429,7 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
426
429
session_work->index_mapping .push_back (i);
427
430
}
428
431
429
- util::MultiWait mwait (session_work_map.size ());
432
+ auto mwait = std::make_shared< util::MultiWait> (session_work_map.size ());
430
433
std::atomic<int64> total_size (0 );
431
434
std::vector<Literal> results (handles.size ());
432
435
for (auto & session_session_work : session_work_map) {
@@ -446,9 +449,10 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
446
449
total_size += results[li].size_bytes ();
447
450
}
448
451
};
449
- env::ScheduleIoClosure (mwait.Completer (std::move (runner)));
452
+ env::ScheduleIoClosure (
453
+ util::MultiWait::Completer (mwait, std::move (runner)));
450
454
}
451
- mwait. Wait ();
455
+ mwait-> Wait ();
452
456
InboundDataMetric ()->AddSample (total_size.load ());
453
457
return results;
454
458
}
@@ -458,7 +462,7 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
458
462
metrics::TimedSection timed (CompileMetric ());
459
463
460
464
std::mutex lock;
461
- util::MultiWait mwait (instances.size ());
465
+ auto mwait = std::make_shared< util::MultiWait> (instances.size ());
462
466
std::vector<ProgramShape> program_shapes (instances.size ());
463
467
std::vector<ComputationPtr> results (instances.size ());
464
468
std::vector<CompilationCacheKey> cache_keys (instances.size ());
@@ -499,10 +503,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
499
503
results[i] = computation_ptr;
500
504
}
501
505
};
502
- env::ScheduleClosure (mwait. Completer (std::move (builder)));
506
+ env::ScheduleClosure (util::MultiWait:: Completer (mwait, std::move (builder)));
503
507
}
504
- mwait. Wait ();
505
- mwait. Reset (session_work_map.size ());
508
+ mwait-> Wait ();
509
+ mwait-> Reset (session_work_map.size ());
506
510
507
511
for (auto & session_and_work : session_work_map) {
508
512
XrtSession* session = session_and_work.first ;
@@ -532,9 +536,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
532
536
CreateCompileHandlesCounter ()->AddValue (1 );
533
537
}
534
538
};
535
- env::ScheduleIoClosure (mwait.Completer (std::move (session_runner)));
539
+ env::ScheduleIoClosure (
540
+ util::MultiWait::Completer (mwait, std::move (session_runner)));
536
541
}
537
- mwait. Wait ();
542
+ mwait-> Wait ();
538
543
return results;
539
544
}
540
545
@@ -626,7 +631,7 @@ XrtComputationClient::RunComputations(
626
631
}
627
632
XLA_CHECK_EQ (computations.size (), devices.size ());
628
633
629
- util::MultiWait mwait (session_replicas.size ());
634
+ auto mwait = std::make_shared< util::MultiWait> (session_replicas.size ());
630
635
std::vector<std::vector<DataPtr>> results (devices.size ());
631
636
for (auto & sess_replica : session_replicas) {
632
637
XrtSession* session = sess_replica.first ;
@@ -655,9 +660,10 @@ XrtComputationClient::RunComputations(
655
660
GetEffectiveDevice (devices[replica]));
656
661
}
657
662
};
658
- env::ScheduleIoClosure (mwait.Completer (std::move (session_runner)));
663
+ env::ScheduleIoClosure (
664
+ util::MultiWait::Completer (mwait, std::move (session_runner)));
659
665
}
660
- mwait. Wait ();
666
+ mwait-> Wait ();
661
667
return results;
662
668
}
663
669
0 commit comments