[8.x] Remove unnecessary interfaces/abstractions from search phases (#120079) #126511

Merged: 2 commits, Apr 9, 2025
@@ -218,7 +218,7 @@ public final void start() {
}

@Override
- public final void run() {
+ protected final void run() {
for (final SearchShardIterator iterator : toSkipShardsIts) {
assert iterator.skip();
skipShard(iterator);
@@ -300,7 +300,7 @@ private static boolean assertExecuteOnStartThread() {
return true;
}

- protected void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
+ private void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) {
if (throttleConcurrentRequests) {
var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent(
shard.getNodeId(),
@@ -363,7 +363,7 @@ protected abstract void executePhaseOnShard(
* of the next phase. If there are no successful operations in the context when this method is executed the search is aborted and
* a response is returned to the user indicating that all shards have failed.
*/
- protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase> nextPhaseSupplier) {
+ protected void executeNextPhase(String currentPhase, Supplier<SearchPhase> nextPhaseSupplier) {
/* This is the main search phase transition where we move to the next phase. If all shards
* failed or if there was a failure and partial results are not allowed, then we immediately
* fail. Otherwise we continue to the next phase.
@@ -374,7 +374,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
Throwable cause = shardSearchFailures.length == 0
? null
: ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0];
- logger.debug(() -> "All shards failed for phase: [" + currentPhase.getName() + "]", cause);
+ logger.debug(() -> "All shards failed for phase: [" + currentPhase + "]", cause);
onPhaseFailure(currentPhase, "all shards failed", cause);
} else {
Boolean allowPartialResults = request.allowPartialSearchResults();
@@ -387,7 +387,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
int numShardFailures = shardSearchFailures.length;
shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures);
Throwable cause = ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0];
- logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase.getName()), cause);
+ logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase), cause);
}
onPhaseFailure(currentPhase, "Partial shards failure", null);
} else {
@@ -400,7 +400,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
successfulOps.get(),
toSkipShardsIts.size(),
getNumShards(),
- currentPhase.getName()
+ currentPhase
);
}
onPhaseFailure(currentPhase, "Partial shards failure (" + discrepancy + " shards unavailable)", null);
@@ -414,7 +414,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
.collect(Collectors.joining(","));
logger.trace(
"[{}] Moving to next phase: [{}], based on results from: {} (cluster state version: {})",
- currentPhase.getName(),
+ currentPhase,
nextPhase.getName(),
resultsFrom,
clusterStateVersion
@@ -427,11 +427,11 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase>
private void executePhase(SearchPhase phase) {
try {
phase.run();
- } catch (Exception e) {
+ } catch (RuntimeException e) {
if (logger.isDebugEnabled()) {
logger.debug(() -> format("Failed to execute [%s] while moving to [%s] phase", request, phase.getName()), e);
}
- onPhaseFailure(phase, "", e);
+ onPhaseFailure(phase.getName(), "", e);
}
}

@@ -686,8 +686,8 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At
* @param msg an optional message
* @param cause the cause of the phase failure
*/
- public void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) {
- raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures()));
+ public void onPhaseFailure(String phase, String msg, Throwable cause) {
+ raisePhaseFailure(new SearchPhaseExecutionException(phase, msg, cause, buildShardFailures()));
}

/**
Expand Down Expand Up @@ -732,7 +732,7 @@ void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connecti
* @see #onShardResult(SearchPhaseResult, SearchShardIterator)
*/
private void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
- executeNextPhase(this, this::getNextPhase);
+ executeNextPhase(getName(), this::getNextPhase);
}

/**
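The hunks above all make the same change: executeNextPhase and onPhaseFailure now take the current phase's name as a plain String instead of a SearchPhase instance, since getName() was the only thing they ever read from that object. A minimal sketch of the resulting call shape from a concrete phase, assuming it lives in the same package as the classes above; the class name and body are invented for illustration, only the two method signatures come from the diff:

    // Illustrative sketch only, not part of this PR. SearchPhase, Supplier and
    // AbstractSearchAsyncAction are the real types used above; ExamplePhase is made up.
    class ExamplePhase extends SearchPhase {
        static final String NAME = "example";
        private final AbstractSearchAsyncAction<?> context;

        ExamplePhase(AbstractSearchAsyncAction<?> context) {
            super(NAME);
            this.context = context;
        }

        @Override
        protected void run() {
            try {
                // ... per-shard work would go here ...
                context.executeNextPhase(NAME, this::nextPhase);    // hand over by name, not by passing "this"
            } catch (RuntimeException e) {
                context.onPhaseFailure(NAME, "example failure", e); // failures are likewise reported by name
            }
        }

        private SearchPhase nextPhase() {
            return null; // stand-in for whatever phase would actually follow
        }
    }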
159 changes: 126 additions & 33 deletions server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java
@@ -8,8 +8,15 @@
*/
package org.elasticsearch.action.search;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
@@ -27,9 +34,11 @@
import org.elasticsearch.transport.Transport;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.function.Function;
import java.util.Map;

/**
* This search phase fans out to every shards to execute a distributed search with a pre-collected distributed frequencies for all
@@ -38,53 +47,50 @@
* operation.
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
*/
- final class DfsQueryPhase extends SearchPhase {
+ class DfsQueryPhase extends SearchPhase {

public static final String NAME = "dfs_query";

private final SearchPhaseResults<SearchPhaseResult> queryResult;
private final List<DfsSearchResult> searchResults;
private final AggregatedDfs dfs;
private final List<DfsKnnResults> knnResults;
private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
private final Client client;
private final AbstractSearchAsyncAction<?> context;
private final SearchTransportService searchTransportService;
private final SearchProgressListener progressListener;

- DfsQueryPhase(
- List<DfsSearchResult> searchResults,
- AggregatedDfs dfs,
- List<DfsKnnResults> knnResults,
- SearchPhaseResults<SearchPhaseResult> queryResult,
- Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
- AbstractSearchAsyncAction<?> context
- ) {
- super("dfs_query");
+ DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
+ super(NAME);
this.progressListener = context.getTask().getProgressListener();
this.queryResult = queryResult;
this.searchResults = searchResults;
this.dfs = dfs;
this.knnResults = knnResults;
this.nextPhaseFactory = nextPhaseFactory;
this.client = client;
this.context = context;
this.searchTransportService = context.getSearchTransport();
}

// protected for testing
protected SearchPhase nextPhase(AggregatedDfs dfs) {
return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
}
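Since run() now derives the aggregated statistics itself, the follow-up phase can no longer be injected through the constructor; nextPhase(AggregatedDfs) is the remaining seam, kept protected so tests can intercept it. A hypothetical test-style override, not taken from this PR; results, client and context stand for collaborators a test would already have set up:

    // Hypothetical sketch: capture the aggregated statistics the real query phase would receive.
    SetOnce<AggregatedDfs> capturedDfs = new SetOnce<>();
    DfsQueryPhase phase = new DfsQueryPhase(results, client, context) {
        @Override
        protected SearchPhase nextPhase(AggregatedDfs dfs) {
            capturedDfs.set(dfs);          // remember the merged statistics
            return super.nextPhase(dfs);   // then build the real follow-up phase
        }
    };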

@SuppressWarnings("unchecked")
@Override
- public void run() {
+ protected void run() {
List<DfsSearchResult> searchResults = (List<DfsSearchResult>) context.results.getAtomicArray().asList();
AggregatedDfs dfs = aggregateDfs(searchResults);
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
// to free up memory early
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
queryResult,
searchResults.size(),
- () -> context.executeNextPhase(this, () -> nextPhaseFactory.apply(queryResult)),
+ () -> context.executeNextPhase(NAME, () -> nextPhase(dfs)),
context
);

List<DfsKnnResults> knnResults = mergeKnnResults(context.getRequest(), searchResults);
for (final DfsSearchResult dfsResult : searchResults) {
final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
final int shardIndex = dfsResult.getShardIndex();
QuerySearchRequest querySearchRequest = new QuerySearchRequest(
context.getOriginalIndices(shardIndex),
dfsResult.getContextId(),
- rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
+ rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
dfs
);
final Transport.Connection connection;
@@ -94,19 +100,16 @@ public void run() {
shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
continue;
}
- searchTransportService.sendExecuteQuery(
- connection,
- querySearchRequest,
- context.getTask(),
- new SearchActionListener<>(shardTarget, shardIndex) {
+ context.getSearchTransport()
+ .sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) {

@Override
protected void innerOnResponse(QuerySearchResult response) {
try {
response.setSearchProfileDfsPhaseResult(dfsResult.searchProfileDfsPhaseResult());
counter.onResult(response);
} catch (Exception e) {
- context.onPhaseFailure(DfsQueryPhase.this, "", e);
+ context.onPhaseFailure(NAME, "", e);
}
}

@@ -123,8 +126,7 @@ public void onFailure(Exception exception) {
}
}
}
- }
- );
+ });
}
}

@@ -141,7 +143,7 @@ private void shardFailure(
}

// package private for testing
- ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
+ ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
SearchSourceBuilder source = request.source();
if (source == null || source.knnSearch().isEmpty()) {
return request;
@@ -177,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {

return request;
}

private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
if (request.hasKnnSearch() == false) {
return null;
}
SearchSourceBuilder source = request.source();
List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
for (int i = 0; i < source.knnSearch().size(); i++) {
topDocsLists.add(new ArrayList<>());
nestedPath.add(new SetOnce<>());
}

for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
if (dfsSearchResult.knnResults() != null) {
for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
ScoreDoc[] scoreDocs = knnResults.scoreDocs();
TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
topDocsLists.get(i).add(shardTopDocs);
nestedPath.get(i).trySet(knnResults.getNestedPath());
}
}
}

List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
for (int i = 0; i < source.knnSearch().size(); i++) {
TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
}
return mergedResults;
}
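mergeKnnResults wraps each shard's kNN hits in a TopDocs, tags every ScoreDoc with its shard index, and lets Lucene's TopDocs.merge pick the global top k per kNN query. A standalone illustration of that merge step, with invented doc ids and scores (not from the PR):

    // Illustrative only: two shards each return their local top hits for one kNN query.
    ScoreDoc[] shard0Hits = { new ScoreDoc(3, 0.9f, 0), new ScoreDoc(7, 0.4f, 0) };
    ScoreDoc[] shard1Hits = { new ScoreDoc(1, 0.8f, 1), new ScoreDoc(5, 0.2f, 1) };
    TopDocs[] perShard = {
        new TopDocs(new TotalHits(shard0Hits.length, TotalHits.Relation.EQUAL_TO), shard0Hits),
        new TopDocs(new TotalHits(shard1Hits.length, TotalHits.Relation.EQUAL_TO), shard1Hits)
    };
    // k = 3: the merged result keeps the three best-scoring docs across both shards,
    // i.e. doc 3 (shard 0), doc 1 (shard 1), doc 7 (shard 0).
    TopDocs merged = TopDocs.merge(3, perShard);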

private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
Map<Term, TermStatistics> termStatistics = new HashMap<>();
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
long aggMaxDoc = 0;
for (DfsSearchResult lEntry : results) {
final Term[] terms = lEntry.terms();
final TermStatistics[] stats = lEntry.termStatistics();
assert terms.length == stats.length;
for (int i = 0; i < terms.length; i++) {
assert terms[i] != null;
if (stats[i] == null) {
continue;
}
TermStatistics existing = termStatistics.get(terms[i]);
if (existing != null) {
assert terms[i].bytes().equals(existing.term());
termStatistics.put(
terms[i],
new TermStatistics(
existing.term(),
existing.docFreq() + stats[i].docFreq(),
existing.totalTermFreq() + stats[i].totalTermFreq()
)
);
} else {
termStatistics.put(terms[i], stats[i]);
}

}

assert lEntry.fieldStatistics().containsKey(null) == false;
for (var entry : lEntry.fieldStatistics().entrySet()) {
String key = entry.getKey();
CollectionStatistics value = entry.getValue();
if (value == null) {
continue;
}
assert key != null;
CollectionStatistics existing = fieldStatistics.get(key);
if (existing != null) {
CollectionStatistics merged = new CollectionStatistics(
key,
existing.maxDoc() + value.maxDoc(),
existing.docCount() + value.docCount(),
existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
existing.sumDocFreq() + value.sumDocFreq()
);
fieldStatistics.put(key, merged);
} else {
fieldStatistics.put(key, value);
}
}
aggMaxDoc += lEntry.maxDoc();
}
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
}
}
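aggregateDfs builds the global term and field statistics that the query phase later uses, so every shard scores documents against the same collection-wide frequencies. For a single term seen on two shards the merge reduces to plain addition; an invented example (BytesRef is org.apache.lucene.util.BytesRef):

    // Illustrative only: the same merge aggregateDfs performs for one term.
    TermStatistics shard0 = new TermStatistics(new BytesRef("elasticsearch"), 10, 25); // docFreq, totalTermFreq
    TermStatistics shard1 = new TermStatistics(new BytesRef("elasticsearch"), 4, 7);
    TermStatistics merged = new TermStatistics(
        shard0.term(),
        shard0.docFreq() + shard1.docFreq(),             // 14 documents across both shards
        shard0.totalTermFreq() + shard1.totalTermFreq()  // 32 occurrences in total
    );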