Skip to content

Fix race condition in RestCancellableNodeClient #126686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/126686.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126686
summary: Fix race condition in `RestCancellableNodeClient`
area: Task Management
type: bug
issues:
- 88201
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,12 @@
import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.indices.segments.IndicesSegmentsAction;
import org.elasticsearch.client.Request;
import org.elasticsearch.test.junit.annotations.TestIssueLogging;

public class IndicesSegmentsRestCancellationIT extends BlockedSearcherRestCancellationTestCase {
@TestIssueLogging(
issueUrl = "https://github.com/elastic/elasticsearch/issues/88201",
value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG"
+ ",org.elasticsearch.transport.TransportService:TRACE"
)
public void testIndicesSegmentsRestCancellation() throws Exception {
    // GET /_segments, handed to runTest together with the name of the action it is expected to trigger
    final Request segmentsRequest = new Request(HttpGet.METHOD_NAME, "/_segments");
    runTest(segmentsRequest, IndicesSegmentsAction.NAME);
}

@TestIssueLogging(
issueUrl = "https://github.com/elastic/elasticsearch/issues/88201",
value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG"
+ ",org.elasticsearch.transport.TransportService:TRACE"
)
public void testCatSegmentsRestCancellation() throws Exception {
    // GET /_cat/segments, handed to runTest together with the same action name as the /_segments test
    final Request catSegmentsRequest = new Request(HttpGet.METHOD_NAME, "/_cat/segments");
    runTest(catSegmentsRequest, IndicesSegmentsAction.NAME);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import org.elasticsearch.client.internal.FilterClient;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -112,12 +112,14 @@ private void cancelTask(TaskId taskId) {

private class CloseListener implements ActionListener<Void> {
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
private final Set<TaskId> tasks = new HashSet<>();

@Nullable // if already drained
private Set<TaskId> tasks = new HashSet<>();

CloseListener() {}

synchronized int getNumTasks() {
return tasks.size();
return tasks == null ? 0 : tasks.size();
}

void maybeRegisterChannel(HttpChannel httpChannel) {
Expand All @@ -130,16 +132,23 @@ void maybeRegisterChannel(HttpChannel httpChannel) {
}
}

synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
taskHolder.taskId = taskId;
if (taskHolder.completed == false) {
this.tasks.add(taskId);
void registerTask(TaskHolder taskHolder, TaskId taskId) {
synchronized (this) {
taskHolder.taskId = taskId;
if (tasks != null) {
if (taskHolder.completed == false) {
tasks.add(taskId);
}
return;
}
}
// else tasks == null so the channel is already closed
cancelTask(taskId);
}

synchronized void unregisterTask(TaskHolder taskHolder) {
if (taskHolder.taskId != null) {
this.tasks.remove(taskHolder.taskId);
if (taskHolder.taskId != null && tasks != null) {
tasks.remove(taskHolder.taskId);
}
taskHolder.completed = true;
}
Expand All @@ -149,18 +158,20 @@ public void onResponse(Void aVoid) {
final HttpChannel httpChannel = channel.get();
assert httpChannel != null : "channel not registered";
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
CloseListener closeListener = httpChannels.remove(httpChannel);
assert closeListener != null : "channel not found in the map of tracked channels";
final List<TaskId> toCancel;
synchronized (this) {
toCancel = new ArrayList<>(tasks);
tasks.clear();
}
for (TaskId taskId : toCancel) {
final CloseListener closeListener = httpChannels.remove(httpChannel);
assert closeListener != null : "channel not found in the map of tracked channels: " + httpChannel;
assert closeListener == CloseListener.this : "channel had a different CloseListener registered: " + httpChannel;
for (final var taskId : drainTasks()) {
cancelTask(taskId);
}
}

/**
 * Hands the currently tracked task ids over to the caller and marks this listener as drained.
 * <p>
 * The hand-over happens under this listener's monitor: {@code tasks} is read and then set to {@code null} in one
 * synchronized step. A {@code null} value of {@code tasks} is what the other synchronized methods of this class
 * check to detect that the channel has already been closed.
 *
 * @return the task ids that were tracked when this method was called; never {@code null}. If the listener had
 *         already been drained, an empty collection is returned so that callers can iterate the result
 *         unconditionally.
 */
private synchronized Collection<TaskId> drainTasks() {
    final var drained = tasks;
    tasks = null;
    // A repeated drain must not leak the null marker to the caller, which iterates the result in a for-each loop.
    return drained == null ? Set.<TaskId>of() : drained;
}

@Override
public void onFailure(Exception e) {
onResponse(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.tasks.Task;
Expand All @@ -44,6 +45,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongSupplier;

public class RestCancellableNodeClientTests extends ESTestCase {

Expand Down Expand Up @@ -148,8 +150,42 @@ public void testChannelAlreadyClosed() {
assertEquals(totalSearches, testClient.cancelledTasks.size());
}

/**
 * Executes a random number of search requests through {@link RestCancellableNodeClient} on the generic thread pool
 * while another thread closes the shared channel, then checks that the channel is no longer tracked and that the
 * set of cancelled tasks is exactly the set of search tasks that were started.
 */
public void testConcurrentExecuteAndClose() throws Exception {
// NOTE(review): the third argument is stored as TestClient's 'timeout' flag; presumably 'true' keeps the searches
// from completing on their own so that only cancellation can end them - confirm against TestClient
final var testClient = new TestClient(Settings.EMPTY, threadPool, true);
// baseline taken before this test registers its channel, compared again at the end
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
int numTasks = randomIntBetween(1, 30);
TestHttpChannel channel = new TestHttpChannel();
// opened by the first search thread that returns from execute(); the closing thread waits on it
final var startLatch = new CountDownLatch(1);
// one count per search thread plus one for the closing thread
final var doneLatch = new CountDownLatch(numTasks + 1);
final var expectedTasks = Sets.<TaskId>newHashSetWithExpectedSize(numTasks);
for (int j = 0; j < numTasks; j++) {
// every request gets its own client instance, all wrapping the same channel
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
threadPool.generic().execute(() -> {
// the listener fails the test if it is ever completed
client.execute(TransportSearchAction.TYPE, new SearchRequest(), ActionListener.running(ESTestCase::fail));
startLatch.countDown();
doneLatch.countDown();
});
// expects the search tasks to carry ids 0..numTasks-1 on the local node; this relies on the test client
// numbering search tasks from zero
expectedTasks.add(new TaskId(testClient.getLocalNodeId(), j));
}
threadPool.generic().execute(() -> {
try {
// only close once at least one execute() call has returned, so the close overlaps with the remaining ones
safeAwait(startLatch);
channel.awaitClose();
} catch (InterruptedException e) {
// restore the interrupt flag before failing
Thread.currentThread().interrupt();
throw new AssertionError(e);
} finally {
doneLatch.countDown();
}
});
// wait for all search threads and the closing thread
safeAwait(doneLatch);
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
assertEquals(expectedTasks, testClient.cancelledTasks);
}

private static class TestClient extends NodeClient {
private final AtomicLong counter = new AtomicLong(0);
private final LongSupplier searchTaskIdGenerator = new AtomicLong(0)::getAndIncrement;
private final LongSupplier cancelTaskIdGenerator = new AtomicLong(1000)::getAndIncrement;
private final Set<TaskId> cancelledTasks = new CopyOnWriteArraySet<>();
private final AtomicInteger searchRequests = new AtomicInteger(0);
private final boolean timeout;
Expand All @@ -167,9 +203,17 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
) {
switch (action.name()) {
case TransportCancelTasksAction.NAME -> {
CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request;
assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTargetTaskId()));
Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap());
assertTrue(
"tried to cancel the same task more than once",
cancelledTasks.add(asInstanceOf(CancelTasksRequest.class, request).getTargetTaskId())
);
Task task = request.createTask(
cancelTaskIdGenerator.getAsLong(),
"cancel_task",
action.name(),
null,
Collections.emptyMap()
);
if (randomBoolean()) {
listener.onResponse(null);
} else {
Expand All @@ -180,7 +224,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
}
case TransportSearchAction.NAME -> {
searchRequests.incrementAndGet();
Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap());
Task searchTask = request.createTask(
searchTaskIdGenerator.getAsLong(),
"search",
action.name(),
null,
Collections.emptyMap()
);
if (timeout == false) {
if (rarely()) {
// make sure that search is sometimes also called from the same thread before the task is returned
Expand All @@ -191,7 +241,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
}
return searchTask;
}
default -> throw new UnsupportedOperationException();
default -> throw new AssertionError("unexpected action " + action.name());
}

}
Expand Down Expand Up @@ -222,10 +272,7 @@ public InetSocketAddress getRemoteAddress() {

@Override
public void close() {
if (open.compareAndSet(true, false) == false) {
assert false : "HttpChannel is already closed";
return; // nothing to do
}
assertTrue("HttpChannel is already closed", open.compareAndSet(true, false));
ActionListener<Void> listener = closeListener.get();
if (listener != null) {
boolean failure = randomBoolean();
Expand All @@ -241,6 +288,7 @@ public void close() {
}

/**
 * Closes this channel and then blocks until {@code closeLatch} is released.
 * Asserts first that a close listener has been registered on the channel.
 *
 * @throws InterruptedException if the calling thread is interrupted while waiting on {@code closeLatch}
 */
private void awaitClose() throws InterruptedException {
// guard against calling this before any listener was attached
assertNotNull("must set closeListener before calling awaitClose", closeListener.get());
close();
closeLatch.await();
}
Expand All @@ -257,7 +305,7 @@ public void addCloseListener(ActionListener<Void> listener) {
listener.onResponse(null);
} else {
if (closeListener.compareAndSet(null, listener) == false) {
throw new IllegalStateException("close listener already set, only one is allowed!");
throw new AssertionError("close listener already set, only one is allowed!");
}
}
}
Expand Down
Loading