Skip to content

Commit f4cc166

Browse files
committed
xds: Copy data in least request to avoid picker data race
In 0d39bf5 the ReadyPicker was changed holding List<Subchannel> to List<ChildLbState>, but ChildLbState mutates over time and is not synchronized. We want the picker to have a snapshot of the data, so copy the data from ChildLbState instead of using it directly. Unfortunately the tests depended on the ChildLbState a bit, so we need to save the EAG only to use it in tests. That's okay for now, but in the future we'll probably want to remove that unnecessary memory usage.
1 parent a231d80 commit f4cc166

File tree

2 files changed

+36
-20
lines changed

2 files changed

+36
-20
lines changed

xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java

+30-15
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import io.grpc.ClientStreamTracer;
3333
import io.grpc.ClientStreamTracer.StreamInfo;
3434
import io.grpc.ConnectivityState;
35+
import io.grpc.EquivalentAddressGroup;
3536
import io.grpc.LoadBalancer;
3637
import io.grpc.LoadBalancerProvider;
3738
import io.grpc.Metadata;
@@ -153,28 +154,37 @@ private static AtomicInteger getInFlights(ChildLbState childLbState) {
153154

154155
@VisibleForTesting
155156
static final class ReadyPicker extends SubchannelPicker {
156-
private final List<ChildLbState> childLbStates; // non-empty
157+
private final List<SubchannelPicker> childPickers; // non-empty
158+
private final List<AtomicInteger> childInFlights; // 1:1 with childPickers
159+
private final List<EquivalentAddressGroup> childEags; // 1:1 with childPickers
157160
private final int choiceCount;
158161
private final ThreadSafeRandom random;
159162
private final int hashCode;
160163

161164
ReadyPicker(List<ChildLbState> childLbStates, int choiceCount, ThreadSafeRandom random) {
162165
checkArgument(!childLbStates.isEmpty(), "empty list");
163-
this.childLbStates = childLbStates;
166+
this.childPickers = new ArrayList<>(childLbStates.size());
167+
this.childInFlights = new ArrayList<>(childLbStates.size());
168+
this.childEags = new ArrayList<>(childLbStates.size());
169+
for (ChildLbState state : childLbStates) {
170+
childPickers.add(state.getCurrentPicker());
171+
childInFlights.add(getInFlights(state));
172+
childEags.add(state.getEag());
173+
}
164174
this.choiceCount = choiceCount;
165175
this.random = checkNotNull(random, "random");
166176

167177
int sum = 0;
168-
for (ChildLbState child : childLbStates) {
178+
for (SubchannelPicker child : childPickers) {
169179
sum += child.hashCode();
170180
}
171181
this.hashCode = sum ^ choiceCount;
172182
}
173183

174184
@Override
175185
public PickResult pickSubchannel(PickSubchannelArgs args) {
176-
final ChildLbState childLbState = nextChildToUse();
177-
PickResult childResult = childLbState.getCurrentPicker().pickSubchannel(args);
186+
int child = nextChildToUse();
187+
PickResult childResult = childPickers.get(child).pickSubchannel(args);
178188

179189
if (!childResult.getStatus().isOk() || childResult.getSubchannel() == null) {
180190
return childResult;
@@ -186,33 +196,38 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
186196
} else {
187197
// Wrap the subchannel
188198
OutstandingRequestsTracingFactory factory =
189-
new OutstandingRequestsTracingFactory(getInFlights(childLbState));
199+
new OutstandingRequestsTracingFactory(childInFlights.get(child));
190200
return PickResult.withSubchannel(childResult.getSubchannel(), factory);
191201
}
192202
}
193203

194204
@Override
195205
public String toString() {
196206
return MoreObjects.toStringHelper(ReadyPicker.class)
197-
.add("list", childLbStates)
207+
.add("list", childPickers)
198208
.add("choiceCount", choiceCount)
199209
.toString();
200210
}
201211

202-
private ChildLbState nextChildToUse() {
203-
ChildLbState candidate = childLbStates.get(random.nextInt(childLbStates.size()));
212+
private int nextChildToUse() {
213+
int candidate = random.nextInt(childPickers.size());
204214
for (int i = 0; i < choiceCount - 1; ++i) {
205-
ChildLbState sampled = childLbStates.get(random.nextInt(childLbStates.size()));
206-
if (getInFlights(sampled).get() < getInFlights(candidate).get()) {
215+
int sampled = random.nextInt(childPickers.size());
216+
if (childInFlights.get(sampled).get() < childInFlights.get(candidate).get()) {
207217
candidate = sampled;
208218
}
209219
}
210220
return candidate;
211221
}
212222

213223
@VisibleForTesting
214-
List<ChildLbState> getChildLbStates() {
215-
return childLbStates;
224+
List<SubchannelPicker> getChildPickers() {
225+
return childPickers;
226+
}
227+
228+
@VisibleForTesting
229+
List<EquivalentAddressGroup> getChildEags() {
230+
return childEags;
216231
}
217232

218233
@Override
@@ -232,8 +247,8 @@ public boolean equals(Object o) {
232247
// the lists cannot contain duplicate children
233248
return hashCode == other.hashCode
234249
&& choiceCount == other.choiceCount
235-
&& childLbStates.size() == other.childLbStates.size()
236-
&& new HashSet<>(childLbStates).containsAll(other.childLbStates);
250+
&& childPickers.size() == other.childPickers.size()
251+
&& new HashSet<>(childPickers).containsAll(other.childPickers);
237252
}
238253
}
239254

xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,12 @@ private String getStatusString(SubchannelPicker picker) {
384384
}
385385

386386
if (picker instanceof ReadyPicker) {
387-
List<ChildLbState> childLbStates = ((ReadyPicker)picker).getChildLbStates();
388-
if (childLbStates == null || childLbStates.isEmpty()) {
387+
List<SubchannelPicker> childPickers = ((ReadyPicker)picker).getChildPickers();
388+
if (childPickers == null || childPickers.isEmpty()) {
389389
return "";
390390
}
391391

392-
picker = childLbStates.get(0).getCurrentPicker();
392+
picker = childPickers.get(0);
393393
}
394394

395395
Status status = picker.pickSubchannel(mockArgs).getStatus();
@@ -460,7 +460,8 @@ public void pickerLeastRequest() throws Exception {
460460

461461
ReadyPicker picker = (ReadyPicker) pickerCaptor.getValue();
462462

463-
assertThat(picker.getChildLbStates()).containsExactlyElementsIn(childLbStates);
463+
assertThat(picker.getChildEags())
464+
.containsExactlyElementsIn(childLbStates.stream().map(ChildLbState::getEag).toArray());
464465

465466
// Make random return 0, then 2 for the sample indexes.
466467
when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2);
@@ -647,7 +648,7 @@ public void emptyAddresses() {
647648

648649
private List<Subchannel> getList(SubchannelPicker picker) {
649650
if (picker instanceof ReadyPicker) {
650-
return ((ReadyPicker) picker).getChildLbStates().stream()
651+
return ((ReadyPicker) picker).getChildEags().stream()
651652
.map(this::getSubchannel)
652653
.collect(Collectors.toList());
653654
} else {

0 commit comments

Comments
 (0)