@@ -33,7 +33,12 @@ DistributedRuntimeClient::DistributedRuntimeClient(
3333 options_ (options) {}
3434
3535DistributedRuntimeClient::~DistributedRuntimeClient () {
36- if (state_ == State::kConnected ) {
36+ bool connected;
37+ {
38+ absl::MutexLock lock (&mu_);
39+ connected = (state_ == State::kConnected );
40+ }
41+ if (connected) {
3742 if (options_.shutdown_on_destruction ) {
3843 Status status = Shutdown ();
3944 if (!status.ok ()) {
@@ -54,15 +59,20 @@ DistributedRuntimeClient::~DistributedRuntimeClient() {
5459 return " kNotConnected" ;
5560 case State::kConnected :
5661 return " kConnected" ;
62+ case State::kShuttingDown :
63+ return " kShuttingDown" ;
5764 case State::kClosed :
5865 return " kClosed" ;
5966 }
6067}
6168
6269xla::Status DistributedRuntimeClient::Connect () {
63- if (state_ != State::kNotConnected ) {
64- return xla::FailedPrecondition (" Connect() called when client in state %s" ,
65- StateToString (state_));
70+ {
71+ absl::MutexLock lock (&mu_);
72+ if (state_ != State::kNotConnected ) {
73+ return xla::FailedPrecondition (" Connect() called when client in state %s" ,
74+ StateToString (state_));
75+ }
6676 }
6777 ConnectRequest request;
6878 request.set_protocol_version (kDistributedRuntimeProtocolVersion );
@@ -107,7 +117,10 @@ xla::Status DistributedRuntimeClient::Connect() {
107117 FromGrpcStatus (status).ToString ()));
108118 }
109119 VLOG (10 ) << " Connect() response: " << response.DebugString ();
110- state_ = State::kConnected ;
120+ {
121+ absl::MutexLock lock (&mu_);
122+ state_ = State::kConnected ;
123+ }
111124 session_id_ = response.session_id ();
112125
113126 heartbeat_thread_.reset (options_.env ->StartThread (
@@ -120,9 +133,12 @@ xla::Status DistributedRuntimeClient::Connect() {
120133xla::Status DistributedRuntimeClient::EnumerateDevices (
121134 const LocalTopologyProto& local_topology,
122135 GlobalTopologyProto* global_topology) {
123- if (state_ != State::kConnected ) {
124- return xla::FailedPrecondition (
125- " EnumerateDevices() called when client not connected." );
136+ {
137+ absl::MutexLock lock (&mu_);
138+ if (state_ != State::kConnected ) {
139+ return xla::FailedPrecondition (
140+ " EnumerateDevices() called when client not connected." );
141+ }
126142 }
127143 ::grpc::ClientContext ctx;
128144 ctx.set_fail_fast (false );
@@ -146,9 +162,13 @@ xla::Status DistributedRuntimeClient::EnumerateDevices(
146162xla::Status DistributedRuntimeClient::Shutdown () {
147163 LOG (INFO) << " Waiting for all distributed JAX tasks to shut down." ;
148164 ::grpc::ClientContext ctx;
149- if (state_ != State::kConnected ) {
150- return xla::FailedPrecondition (
151- " Shutdown() called when client not connected." );
165+ {
166+ absl::MutexLock lock (&mu_);
167+ if (state_ != State::kConnected ) {
168+ return xla::FailedPrecondition (
169+ " Shutdown() called when client not connected." );
170+ }
171+ state_ = State::kShuttingDown ;
152172 }
153173 ctx.set_fail_fast (false );
154174 ctx.set_deadline (absl::ToChronoTime (absl::Now () + options_.shutdown_timeout ));
@@ -165,15 +185,19 @@ xla::Status DistributedRuntimeClient::Shutdown() {
165185 stop_heartbeats_.Notify ();
166186 }
167187 VLOG (10 ) << " Shutdown() response: " << response.DebugString ();
188+ absl::MutexLock lock (&mu_);
168189 state_ = State::kClosed ;
169190 return xla::Status::OK ();
170191}
171192
172193xla::StatusOr<std::string> DistributedRuntimeClient::BlockingKeyValueGet (
173194 std::string key, absl::Duration timeout) {
174- if (state_ != State::kConnected ) {
175- return xla::FailedPrecondition (
176- " BlockingKeyValueGet() called when client not connected." );
195+ {
196+ absl::MutexLock lock (&mu_);
197+ if (state_ != State::kConnected ) {
198+ return xla::FailedPrecondition (
199+ " BlockingKeyValueGet() called when client not connected." );
200+ }
177201 }
178202 ::grpc::ClientContext ctx;
179203 ctx.set_fail_fast (false );
@@ -194,9 +218,12 @@ xla::StatusOr<std::string> DistributedRuntimeClient::BlockingKeyValueGet(
194218
195219xla::Status DistributedRuntimeClient::KeyValueSet (std::string key,
196220 std::string value) {
197- if (state_ != State::kConnected ) {
198- return xla::FailedPrecondition (
199- " KeyValueSet() called when client not connected." );
221+ {
222+ absl::MutexLock lock (&mu_);
223+ if (state_ != State::kConnected ) {
224+ return xla::FailedPrecondition (
225+ " KeyValueSet() called when client not connected." );
226+ }
200227 }
201228 ::grpc::ClientContext ctx;
202229 ctx.set_fail_fast (false );
@@ -239,8 +266,14 @@ void DistributedRuntimeClient::HeartbeatLoop() {
239266 if (!stop_heartbeats_.HasBeenNotified () &&
240267 (!is_transient_error ||
241268 num_missing_heartbeats > options_.max_missing_heartbeats )) {
242- options_.missed_heartbeat_callback (FromGrpcStatus (status),
243- !is_transient_error);
269+ // If we are shutting down, missed heartbeats are benign: they may
270+ // simply mean that the server has shut down already before it saw
271+ // the heartbeat request.
272+ absl::MutexLock lock (&mu_);
273+ if (state_ != State::kShuttingDown ) {
274+ options_.missed_heartbeat_callback (FromGrpcStatus (status),
275+ !is_transient_error);
276+ }
244277 return ;
245278 }
246279 }
0 commit comments