Skip to content

Cheaper handling of skipped shard iterators in AbstractSearchAsyncAction (#124223) #126533

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchContextId;
Expand Down Expand Up @@ -87,18 +86,18 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final SetOnce<AtomicArray<ShardSearchFailure>> shardFailures = new SetOnce<>();
private final Object shardFailuresMutex = new Object();
private final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
private final AtomicInteger successfulOps = new AtomicInteger();
private final AtomicInteger successfulOps;
private final SearchTimeProvider timeProvider;
private final SearchResponse.Clusters clusters;

protected final List<SearchShardIterator> toSkipShardsIts;
protected final List<SearchShardIterator> shardsIts;
private final SearchShardIterator[] shardIterators;
private final AtomicInteger outstandingShards;
private final int maxConcurrentRequestsPerNode;
private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
private final boolean throttleConcurrentRequests;
private final AtomicBoolean requestCancelled = new AtomicBoolean();
private final int skippedCount;

// protected for tests
protected final SubscribableListener<Void> doneFuture = new SubscribableListener<>();
Expand All @@ -124,18 +123,19 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
) {
super(name);
this.namedWriteableRegistry = namedWriteableRegistry;
final List<SearchShardIterator> toSkipIterators = new ArrayList<>();
final List<SearchShardIterator> iterators = new ArrayList<>();
int skipped = 0;
for (final SearchShardIterator iterator : shardsIts) {
if (iterator.skip()) {
toSkipIterators.add(iterator);
skipped++;
} else {
iterators.add(iterator);
}
}
this.toSkipShardsIts = toSkipIterators;
this.skippedCount = skipped;
this.shardsIts = iterators;
outstandingShards = new AtomicInteger(shardsIts.size());
outstandingShards = new AtomicInteger(iterators.size());
successfulOps = new AtomicInteger(skipped);
this.shardIterators = iterators.toArray(new SearchShardIterator[0]);
// we later compute the shard index based on the natural order of the shards
// that participate in the search request. This means that this number is
Expand Down Expand Up @@ -166,11 +166,19 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
protected void notifyListShards(
SearchProgressListener progressListener,
SearchResponse.Clusters clusters,
SearchSourceBuilder sourceBuilder
SearchRequest searchRequest,
List<SearchShardIterator> allIterators
) {
final List<SearchShard> skipped = new ArrayList<>(allIterators.size() - shardsIts.size());
for (SearchShardIterator iter : allIterators) {
if (iter.skip()) {
skipped.add(new SearchShard(iter.getClusterAlias(), iter.shardId()));
}
}
var sourceBuilder = searchRequest.source();
progressListener.notifyListShards(
SearchProgressListener.buildSearchShardsFromIter(this.shardsIts),
SearchProgressListener.buildSearchShardsFromIter(toSkipShardsIts),
skipped,
clusters,
sourceBuilder == null || sourceBuilder.size() > 0,
timeProvider
Expand Down Expand Up @@ -219,44 +227,37 @@ public final void start() {

@Override
protected final void run() {
for (final SearchShardIterator iterator : toSkipShardsIts) {
assert iterator.skip();
skipShard(iterator);
if (outstandingShards.get() == 0) {
onPhaseDone();
return;
}
final Map<SearchShardIterator, Integer> shardIndexMap = Maps.newHashMapWithExpectedSize(shardIterators.length);
for (int i = 0; i < shardIterators.length; i++) {
shardIndexMap.put(shardIterators[i], i);
}
if (shardsIts.size() > 0) {
doCheckNoMissingShards(getName(), request, shardsIts);
Version version = request.minCompatibleShardNode();
if (version != null && Version.CURRENT.minimumCompatibilityVersion().equals(version) == false) {
if (checkMinimumVersion(shardsIts) == false) {
throw new VersionMismatchException(
"One of the shards is incompatible with the required minimum version [{}]",
request.minCompatibleShardNode()
);
}
}
for (int i = 0; i < shardsIts.size(); i++) {
final SearchShardIterator shardRoutings = shardsIts.get(i);
assert shardRoutings.skip() == false;
assert shardIndexMap.containsKey(shardRoutings);
int shardIndex = shardIndexMap.get(shardRoutings);
final SearchShardTarget routing = shardRoutings.nextOrNull();
if (routing == null) {
failOnUnavailable(shardIndex, shardRoutings);
} else {
performPhaseOnShard(shardIndex, shardRoutings, routing);
}
doCheckNoMissingShards(getName(), request, shardsIts);
Version version = request.minCompatibleShardNode();
if (version != null && Version.CURRENT.minimumCompatibilityVersion().equals(version) == false) {
if (checkMinimumVersion(shardsIts) == false) {
throw new VersionMismatchException(
"One of the shards is incompatible with the required minimum version [{}]",
request.minCompatibleShardNode()
);
}
}
}
for (int i = 0; i < shardsIts.size(); i++) {
final SearchShardIterator shardRoutings = shardsIts.get(i);
assert shardRoutings.skip() == false;
assert shardIndexMap.containsKey(shardRoutings);
int shardIndex = shardIndexMap.get(shardRoutings);
final SearchShardTarget routing = shardRoutings.nextOrNull();
if (routing == null) {
failOnUnavailable(shardIndex, shardRoutings);
} else {
performPhaseOnShard(shardIndex, shardRoutings, routing);
}

void skipShard(SearchShardIterator iterator) {
successfulOps.incrementAndGet();
assert iterator.skip();
successfulShardExecution();
}
}

private boolean checkMinimumVersion(List<SearchShardIterator> shardsIts) {
Expand All @@ -274,32 +275,6 @@ private boolean checkMinimumVersion(List<SearchShardIterator> shardsIts) {
return true;
}

/**
 * Assertion helper that inspects the current thread's stack trace and checks that it was reached through the
 * expected chain of callers. It is written to be invoked as {@code assert assertExecuteOnStartThread();}.
 *
 * @return always {@code true}; every actual check is performed by the {@code assert} statements in the body
 */
private static boolean assertExecuteOnStartThread() {
// Expected frames, innermost first:
// Thread#getStackTrace
// -> assertExecuteOnStartThread
// -> failOnUnavailable
// -> (optional) CanMatchPreFilterSearchPhase#performPhaseOnShard
// -> AbstractSearchAsyncAction#run
// -> AbstractSearchAsyncAction#executePhase
// -> AbstractSearchAsyncAction#start
// -> a caller whose class name does not end with "AbstractSearchAsyncAction"
final StackTraceElement[] stackTraceElements = Thread.currentThread().getStackTrace();
// NOTE(review): the six mandatory frames occupy indices 0..5, but the final check below reads index 6
// (or 7 when the optional frame is present), which this length check does not guarantee by itself.
assert stackTraceElements.length >= 6 : stackTraceElements;
// index is advanced inside the assert expressions, so the increments only run when assertions are enabled
int index = 0;
assert stackTraceElements[index++].getMethodName().equals("getStackTrace");
assert stackTraceElements[index++].getMethodName().equals("assertExecuteOnStartThread");
assert stackTraceElements[index++].getMethodName().equals("failOnUnavailable");
// tolerate one extra performPhaseOnShard frame, but only when it belongs to CanMatchPreFilterSearchPhase
if (stackTraceElements[index].getMethodName().equals("performPhaseOnShard")) {
assert stackTraceElements[index].getClassName().endsWith("CanMatchPreFilterSearchPhase");
index++;
}
// each of the next three frames: class name is checked first, then the method name check advances index
assert stackTraceElements[index].getClassName().endsWith("AbstractSearchAsyncAction");
assert stackTraceElements[index++].getMethodName().equals("run");

assert stackTraceElements[index].getClassName().endsWith("AbstractSearchAsyncAction");
assert stackTraceElements[index++].getMethodName().equals("executePhase");

assert stackTraceElements[index].getClassName().endsWith("AbstractSearchAsyncAction");
assert stackTraceElements[index++].getMethodName().equals("start");

// the frame that called start() must come from a class not named ...AbstractSearchAsyncAction
assert stackTraceElements[index].getClassName().endsWith("AbstractSearchAsyncAction") == false;
return true;
}

private void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
if (throttleConcurrentRequests) {
var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent(
Expand All @@ -318,7 +293,7 @@ private void doPerformPhaseOnShard(int shardIndex, SearchShardIterator shardIt,
public void innerOnResponse(Result result) {
try {
releasable.close();
onShardResult(result, shardIt);
onShardResult(result);
} catch (Exception exc) {
onShardFailure(shardIndex, shard, shardIt, exc);
}
Expand All @@ -341,7 +316,6 @@ public void onFailure(Exception e) {
}

/**
 * Fails the given shard with a {@link NoShardAvailableActionException} by delegating to {@code onShardFailure}.
 *
 * @param shardIndex index of the shard, passed through unchanged to {@code onShardFailure}
 * @param shardIt    iterator that supplies the shard id and cluster alias for the reported target
 */
private void failOnUnavailable(int shardIndex, SearchShardIterator shardIt) {
    assert assertExecuteOnStartThread();
    // the reported target carries only the shard id and cluster alias; its first constructor argument is null
    final SearchShardTarget placeholderTarget = new SearchShardTarget(null, shardIt.shardId(), shardIt.getClusterAlias());
    final NoShardAvailableActionException noShardAvailable = new NoShardAvailableActionException(shardIt.shardId());
    onShardFailure(shardIndex, placeholderTarget, shardIt, noShardAvailable);
}
Expand Down Expand Up @@ -398,7 +372,7 @@ protected void executeNextPhase(String currentPhase, Supplier<SearchPhase> nextP
"Partial shards failure (unavailable: {}, successful: {}, skipped: {}, num-shards: {}, phase: {})",
discrepancy,
successfulOps.get(),
toSkipShardsIts.size(),
skippedCount,
getNumShards(),
currentPhase
);
Expand Down Expand Up @@ -537,9 +511,8 @@ void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Excepti
/**
* Executed once for every successful shard level request.
* @param result the result returned from the shard
* @param shardIt the shard iterator
*/
protected void onShardResult(Result result, SearchShardIterator shardIt) {
protected void onShardResult(Result result) {
assert result.getShardIndex() != -1 : "shard index is not set";
assert result.getSearchShardTarget() != null : "search shard target must not be null";
hasShardResponse.set(true);
Expand Down Expand Up @@ -637,7 +610,7 @@ private SearchResponse buildSearchResponse(
scrollId,
getNumShards(),
numSuccess,
toSkipShardsIts.size(),
skippedCount,
buildTookInMillis(),
failures,
clusters,
Expand Down Expand Up @@ -729,7 +702,7 @@ void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connecti
/**
* Executed once all shard results have been received and processed
* @see #onShardFailure(int, SearchShardTarget, Exception)
* @see #onShardResult(SearchPhaseResult, SearchShardIterator)
* @see #onShardResult(SearchPhaseResult)
*/
private void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
executeNextPhase(getName(), this::getNextPhase);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
this.progressListener = task.getProgressListener();
// don't build the SearchShard list (can be expensive) if the SearchProgressListener won't use it
if (progressListener != SearchProgressListener.NOOP) {
notifyListShards(progressListener, clusters, request.source());
notifyListShards(progressListener, clusters, request, shardsIts);
}
this.client = client;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh

// don't build the SearchShard list (can be expensive) if the SearchProgressListener won't use it
if (progressListener != SearchProgressListener.NOOP) {
notifyListShards(progressListener, clusters, request.source());
notifyListShards(progressListener, clusters, request, shardsIts);
}
}

Expand All @@ -104,7 +104,7 @@ protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget
}

@Override
protected void onShardResult(SearchPhaseResult result, SearchShardIterator shardIt) {
protected void onShardResult(SearchPhaseResult result) {
QuerySearchResult queryResult = result.queryResult();
if (queryResult.isNull() == false
// disable sort optims for scroll requests because they keep track of the last bottom doc locally (per shard)
Expand All @@ -123,7 +123,7 @@ && getRequest().scroll() == null
}
bottomSortCollector.consumeTopDocs(topDocs, queryResult.sortValueFormats());
}
super.onShardResult(result, shardIt);
super.onShardResult(result);
}

static SearchPhase nextPhase(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,17 @@ public void testShardNotAvailableWithDisallowPartialFailures() {
ArraySearchPhaseResults<SearchPhaseResult> phaseResults = new ArraySearchPhaseResults<>(numShards);
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong());
// skip one to avoid the "all shards failed" failure.
SearchShardIterator skipIterator = new SearchShardIterator(null, null, Collections.emptyList(), null);
skipIterator.skip(true);
action.skipShard(skipIterator);
action.onShardResult(new SearchPhaseResult() {
@Override
public int getShardIndex() {
return 0;
}

@Override
public SearchShardTarget getSearchShardTarget() {
return new SearchShardTarget(null, null, null);
}
});
assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class));
SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException) exception.get();
assertEquals("Partial shards failure (" + (numShards - 1) + " shards unavailable)", searchPhaseExecutionException.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ protected void executePhaseOnShard(
SearchActionListener<SearchPhaseResult> listener
) {
onShardResult(new SearchPhaseResult() {
}, shardIt);
});
}

@Override
Expand Down