Skip to content

Commit 3b50f01

Browse files
hawkinsp authored and tensorflower-gardener committed
[JAX] Add a shutting-down state to the distributed PjRt client to suppress heartbeat errors while shutting down.
PiperOrigin-RevId: 345080074 Change-Id: I0c7488eab91a1260b905b7ef011dd3a405ac8fae
1 parent 70038da commit 3b50f01

File tree

2 files changed

+60
-20
lines changed

2 files changed

+60
-20
lines changed

tensorflow/compiler/xla/pjrt/distributed/client.cc

Lines changed: 52 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,12 @@ DistributedRuntimeClient::DistributedRuntimeClient(
3333
options_(options) {}
3434

3535
DistributedRuntimeClient::~DistributedRuntimeClient() {
36-
if (state_ == State::kConnected) {
36+
bool connected;
37+
{
38+
absl::MutexLock lock(&mu_);
39+
connected = (state_ == State::kConnected);
40+
}
41+
if (connected) {
3742
if (options_.shutdown_on_destruction) {
3843
Status status = Shutdown();
3944
if (!status.ok()) {
@@ -54,15 +59,20 @@ DistributedRuntimeClient::~DistributedRuntimeClient() {
5459
return "kNotConnected";
5560
case State::kConnected:
5661
return "kConnected";
62+
case State::kShuttingDown:
63+
return "kShuttingDown";
5764
case State::kClosed:
5865
return "kClosed";
5966
}
6067
}
6168

6269
xla::Status DistributedRuntimeClient::Connect() {
63-
if (state_ != State::kNotConnected) {
64-
return xla::FailedPrecondition("Connect() called when client in state %s",
65-
StateToString(state_));
70+
{
71+
absl::MutexLock lock(&mu_);
72+
if (state_ != State::kNotConnected) {
73+
return xla::FailedPrecondition("Connect() called when client in state %s",
74+
StateToString(state_));
75+
}
6676
}
6777
ConnectRequest request;
6878
request.set_protocol_version(kDistributedRuntimeProtocolVersion);
@@ -107,7 +117,10 @@ xla::Status DistributedRuntimeClient::Connect() {
107117
FromGrpcStatus(status).ToString()));
108118
}
109119
VLOG(10) << "Connect() response: " << response.DebugString();
110-
state_ = State::kConnected;
120+
{
121+
absl::MutexLock lock(&mu_);
122+
state_ = State::kConnected;
123+
}
111124
session_id_ = response.session_id();
112125

113126
heartbeat_thread_.reset(options_.env->StartThread(
@@ -120,9 +133,12 @@ xla::Status DistributedRuntimeClient::Connect() {
120133
xla::Status DistributedRuntimeClient::EnumerateDevices(
121134
const LocalTopologyProto& local_topology,
122135
GlobalTopologyProto* global_topology) {
123-
if (state_ != State::kConnected) {
124-
return xla::FailedPrecondition(
125-
"EnumerateDevices() called when client not connected.");
136+
{
137+
absl::MutexLock lock(&mu_);
138+
if (state_ != State::kConnected) {
139+
return xla::FailedPrecondition(
140+
"EnumerateDevices() called when client not connected.");
141+
}
126142
}
127143
::grpc::ClientContext ctx;
128144
ctx.set_fail_fast(false);
@@ -146,9 +162,13 @@ xla::Status DistributedRuntimeClient::EnumerateDevices(
146162
xla::Status DistributedRuntimeClient::Shutdown() {
147163
LOG(INFO) << "Waiting for all distributed JAX tasks to shut down.";
148164
::grpc::ClientContext ctx;
149-
if (state_ != State::kConnected) {
150-
return xla::FailedPrecondition(
151-
"Shutdown() called when client not connected.");
165+
{
166+
absl::MutexLock lock(&mu_);
167+
if (state_ != State::kConnected) {
168+
return xla::FailedPrecondition(
169+
"Shutdown() called when client not connected.");
170+
}
171+
state_ = State::kShuttingDown;
152172
}
153173
ctx.set_fail_fast(false);
154174
ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.shutdown_timeout));
@@ -165,15 +185,19 @@ xla::Status DistributedRuntimeClient::Shutdown() {
165185
stop_heartbeats_.Notify();
166186
}
167187
VLOG(10) << "Shutdown() response: " << response.DebugString();
188+
absl::MutexLock lock(&mu_);
168189
state_ = State::kClosed;
169190
return xla::Status::OK();
170191
}
171192

172193
xla::StatusOr<std::string> DistributedRuntimeClient::BlockingKeyValueGet(
173194
std::string key, absl::Duration timeout) {
174-
if (state_ != State::kConnected) {
175-
return xla::FailedPrecondition(
176-
"BlockingKeyValueGet() called when client not connected.");
195+
{
196+
absl::MutexLock lock(&mu_);
197+
if (state_ != State::kConnected) {
198+
return xla::FailedPrecondition(
199+
"BlockingKeyValueGet() called when client not connected.");
200+
}
177201
}
178202
::grpc::ClientContext ctx;
179203
ctx.set_fail_fast(false);
@@ -194,9 +218,12 @@ xla::StatusOr<std::string> DistributedRuntimeClient::BlockingKeyValueGet(
194218

195219
xla::Status DistributedRuntimeClient::KeyValueSet(std::string key,
196220
std::string value) {
197-
if (state_ != State::kConnected) {
198-
return xla::FailedPrecondition(
199-
"KeyValueSet() called when client not connected.");
221+
{
222+
absl::MutexLock lock(&mu_);
223+
if (state_ != State::kConnected) {
224+
return xla::FailedPrecondition(
225+
"KeyValueSet() called when client not connected.");
226+
}
200227
}
201228
::grpc::ClientContext ctx;
202229
ctx.set_fail_fast(false);
@@ -239,8 +266,14 @@ void DistributedRuntimeClient::HeartbeatLoop() {
239266
if (!stop_heartbeats_.HasBeenNotified() &&
240267
(!is_transient_error ||
241268
num_missing_heartbeats > options_.max_missing_heartbeats)) {
242-
options_.missed_heartbeat_callback(FromGrpcStatus(status),
243-
!is_transient_error);
269+
// If we are shutting down, missed heartbeats are benign: they may
270+
// simply mean that the server has shut down already before it saw
271+
// the heartbeat request.
272+
absl::MutexLock lock(&mu_);
273+
if (state_ != State::kShuttingDown) {
274+
options_.missed_heartbeat_callback(FromGrpcStatus(status),
275+
!is_transient_error);
276+
}
244277
return;
245278
}
246279
}

tensorflow/compiler/xla/pjrt/distributed/client.h

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -134,14 +134,21 @@ class DistributedRuntimeClient {
134134
// connection is healthy.
135135
kConnected,
136136

137+
// The client is in the process of shutting down, i.e., Shutdown() has been
138+
// called.
139+
kShuttingDown,
140+
137141
// The client has shut down its server connection, either due to an error
138142
// or due to an explicit shutdown.
139143
kClosed,
140144
};
141145

142146
static absl::string_view StateToString(State state);
143147

144-
State state_ = State::kNotConnected;
148+
// state_ is protected by a mutex because the heartbeat thread needs to look
149+
// at it.
150+
absl::Mutex mu_;
151+
State state_ GUARDED_BY(mu_) = State::kNotConnected;
145152

146153
// A unique session ID, assigned by the server during Connect().
147154
uint64 session_id_;

0 commit comments

Comments
 (0)