[8.x] Remove unnecessary interfaces/abstractions from search phases (#120079) #126511

Merged: 2 commits, Apr 9, 2025
Use inheritance instead of composition to simplify search phase transitions (#119272)

We only need the extensibility for testing, and the code is much easier to reason about
with explicit methods instead of overly complicated composition that retains lots of
redundant references all over the place.

-> Let's simplify to inheritance as a first step: we get shorter code that behaves more
predictably, especially with respect to memory. This also opens the door to further
simplifications and to dropping more retained state/memory as we move through the search phases.
original-brownbear committed Apr 9, 2025
commit 0031ac03e2a184a2e8e01ad0df61e92c66984fae
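Before diving into the diff, here is a minimal, self-contained Java sketch of the shape of the change the commit message describes. The names below (Phase, QueryPhaseWithFactory, QueryPhase, nextPhase) are simplified stand-ins for illustration only, not the actual Elasticsearch types: before, a phase received a factory function for the next phase and retained it as a field; after, the transition is an explicit overridable method, so only tests need to subclass it.

import java.util.function.Function;

class PhaseTransitionSketch {

    // Common phase shape used in this sketch.
    abstract static class Phase {
        abstract void run();
    }

    // Before: composition. The next phase is injected as a factory and retained for the
    // whole lifetime of this phase, even though it is only needed once at the hand-off.
    static final class QueryPhaseWithFactory extends Phase {
        private final Function<String, Phase> nextPhaseFactory;

        QueryPhaseWithFactory(Function<String, Phase> nextPhaseFactory) {
            this.nextPhaseFactory = nextPhaseFactory;
        }

        @Override
        void run() {
            String results = "query results"; // stand-in for the collected per-shard results
            nextPhaseFactory.apply(results).run();
        }
    }

    // After: inheritance. The transition is an explicit method; production code uses the
    // default implementation, and tests can subclass and override it.
    static class QueryPhase extends Phase {

        // protected for testing, mirroring the pattern in the diff
        protected Phase nextPhase(String results) {
            return new Phase() {
                @Override
                void run() {
                    System.out.println("next phase ran with: " + results);
                }
            };
        }

        @Override
        void run() {
            String results = "query results";
            nextPhase(results).run();
        }
    }

    public static void main(String[] args) {
        new QueryPhaseWithFactory(results -> new Phase() {
            @Override
            void run() {
                System.out.println("next phase (via factory) ran with: " + results);
            }
        }).run();

        new QueryPhase().run();
    }
}

With the second shape there is no factory reference to carry around, and the hand-off between phases is a single named method rather than a captured closure.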
150 changes: 120 additions & 30 deletions server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java
@@ -8,8 +8,15 @@
*/
package org.elasticsearch.action.search;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
@@ -27,9 +34,11 @@
import org.elasticsearch.transport.Transport;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.function.Function;
import java.util.Map;

/**
* This search phase fans out to every shards to execute a distributed search with a pre-collected distributed frequencies for all
@@ -38,56 +47,50 @@
* operation.
* @see CountedCollector#onFailure(int, SearchShardTarget, Exception)
*/
final class DfsQueryPhase extends SearchPhase {
class DfsQueryPhase extends SearchPhase {

public static final String NAME = "dfs_query";

private final SearchPhaseResults<SearchPhaseResult> queryResult;
private final List<DfsSearchResult> searchResults;
private final AggregatedDfs dfs;
private final List<DfsKnnResults> knnResults;
private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
private final Client client;
private final AbstractSearchAsyncAction<?> context;
private final SearchTransportService searchTransportService;
private final SearchProgressListener progressListener;

DfsQueryPhase(
List<DfsSearchResult> searchResults,
AggregatedDfs dfs,
List<DfsKnnResults> knnResults,
SearchPhaseResults<SearchPhaseResult> queryResult,
Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
AbstractSearchAsyncAction<?> context
) {
DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
super(NAME);
this.progressListener = context.getTask().getProgressListener();
this.queryResult = queryResult;
this.searchResults = searchResults;
this.dfs = dfs;
this.knnResults = knnResults;
this.nextPhaseFactory = nextPhaseFactory;
this.client = client;
this.context = context;
this.searchTransportService = context.getSearchTransport();
}

// protected for testing
protected SearchPhase nextPhase(AggregatedDfs dfs) {
return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs);
}

@SuppressWarnings("unchecked")
@Override
protected void run() {
List<DfsSearchResult> searchResults = (List<DfsSearchResult>) context.results.getAtomicArray().asList();
AggregatedDfs dfs = aggregateDfs(searchResults);
// TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs
// to free up memory early
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(
queryResult,
searchResults.size(),
() -> context.executeNextPhase(NAME, () -> nextPhaseFactory.apply(queryResult)),
() -> context.executeNextPhase(NAME, () -> nextPhase(dfs)),
context
);

List<DfsKnnResults> knnResults = mergeKnnResults(context.getRequest(), searchResults);
for (final DfsSearchResult dfsResult : searchResults) {
final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
final int shardIndex = dfsResult.getShardIndex();
QuerySearchRequest querySearchRequest = new QuerySearchRequest(
context.getOriginalIndices(shardIndex),
dfsResult.getContextId(),
rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
dfs
);
final Transport.Connection connection;
@@ -97,11 +100,8 @@ protected void run() {
shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
continue;
}
searchTransportService.sendExecuteQuery(
connection,
querySearchRequest,
context.getTask(),
new SearchActionListener<>(shardTarget, shardIndex) {
context.getSearchTransport()
.sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) {

@Override
protected void innerOnResponse(QuerySearchResult response) {
@@ -126,8 +126,7 @@ public void onFailure(Exception exception) {
}
}
}
}
);
});
}
}

@@ -144,7 +143,7 @@ private void shardFailure(
}

// package private for testing
ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
SearchSourceBuilder source = request.source();
if (source == null || source.knnSearch().isEmpty()) {
return request;
@@ -180,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {

return request;
}

private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
if (request.hasKnnSearch() == false) {
return null;
}
SearchSourceBuilder source = request.source();
List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
for (int i = 0; i < source.knnSearch().size(); i++) {
topDocsLists.add(new ArrayList<>());
nestedPath.add(new SetOnce<>());
}

for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
if (dfsSearchResult.knnResults() != null) {
for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
ScoreDoc[] scoreDocs = knnResults.scoreDocs();
TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
topDocsLists.get(i).add(shardTopDocs);
nestedPath.get(i).trySet(knnResults.getNestedPath());
}
}
}

List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
for (int i = 0; i < source.knnSearch().size(); i++) {
TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
}
return mergedResults;
}

private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
Map<Term, TermStatistics> termStatistics = new HashMap<>();
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
long aggMaxDoc = 0;
for (DfsSearchResult lEntry : results) {
final Term[] terms = lEntry.terms();
final TermStatistics[] stats = lEntry.termStatistics();
assert terms.length == stats.length;
for (int i = 0; i < terms.length; i++) {
assert terms[i] != null;
if (stats[i] == null) {
continue;
}
TermStatistics existing = termStatistics.get(terms[i]);
if (existing != null) {
assert terms[i].bytes().equals(existing.term());
termStatistics.put(
terms[i],
new TermStatistics(
existing.term(),
existing.docFreq() + stats[i].docFreq(),
existing.totalTermFreq() + stats[i].totalTermFreq()
)
);
} else {
termStatistics.put(terms[i], stats[i]);
}

}

assert lEntry.fieldStatistics().containsKey(null) == false;
for (var entry : lEntry.fieldStatistics().entrySet()) {
String key = entry.getKey();
CollectionStatistics value = entry.getValue();
if (value == null) {
continue;
}
assert key != null;
CollectionStatistics existing = fieldStatistics.get(key);
if (existing != null) {
CollectionStatistics merged = new CollectionStatistics(
key,
existing.maxDoc() + value.maxDoc(),
existing.docCount() + value.docCount(),
existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
existing.sumDocFreq() + value.sumDocFreq()
);
fieldStatistics.put(key, merged);
} else {
fieldStatistics.put(key, value);
}
}
aggMaxDoc += lEntry.maxDoc();
}
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
}
}
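The new transition methods are protected rather than private precisely because, as the commit message notes, the extensibility is only needed for testing (see the "protected for testing" comment on nextPhase above). Continuing with the simplified stand-in types from the sketch after the commit message, and assuming the class below lives in the same package as that sketch (this is not the real Elasticsearch test infrastructure), a test double might look roughly like this:

// Records the hand-off instead of running a real next phase.
class RecordingQueryPhase extends PhaseTransitionSketch.QueryPhase {

    String capturedResults;

    @Override
    protected PhaseTransitionSketch.Phase nextPhase(String results) {
        this.capturedResults = results; // the test can assert on this afterwards
        return new PhaseTransitionSketch.Phase() {
            @Override
            void run() {
                // no-op: the test only cares about what was handed over
            }
        };
    }
}

The single overridable method is the whole test seam; production code pays nothing for it beyond the protected modifier.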
server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java
@@ -12,37 +12,47 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;

import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;

/**
* This search phase is an optional phase that will be executed once all hits are fetched from the shards that executes
* field-collapsing on the inner hits. This phase only executes if field collapsing is requested in the search request and otherwise
* forwards to the next phase immediately.
*/
final class ExpandSearchPhase extends SearchPhase {
class ExpandSearchPhase extends SearchPhase {

static final String NAME = "expand";

private final AbstractSearchAsyncAction<?> context;
private final SearchHits searchHits;
private final Supplier<SearchPhase> nextPhase;
private final SearchResponseSections searchResponseSections;
private final AtomicArray<SearchPhaseResult> queryPhaseResults;

ExpandSearchPhase(AbstractSearchAsyncAction<?> context, SearchHits searchHits, Supplier<SearchPhase> nextPhase) {
ExpandSearchPhase(
AbstractSearchAsyncAction<?> context,
SearchResponseSections searchResponseSections,
AtomicArray<SearchPhaseResult> queryPhaseResults
) {
super(NAME);
this.context = context;
this.searchHits = searchHits;
this.nextPhase = nextPhase;
this.searchResponseSections = searchResponseSections;
this.queryPhaseResults = queryPhaseResults;
}

// protected for tests
protected SearchPhase nextPhase() {
return new FetchLookupFieldsPhase(context, searchResponseSections, queryPhaseResults);
}

/**
@@ -55,14 +65,15 @@ private boolean isCollapseRequest() {

@Override
protected void run() {
var searchHits = searchResponseSections.hits();
if (isCollapseRequest() == false || searchHits.getHits().length == 0) {
onPhaseDone();
} else {
doRun();
doRun(searchHits);
}
}

private void doRun() {
private void doRun(SearchHits searchHits) {
SearchRequest searchRequest = context.getRequest();
CollapseBuilder collapseBuilder = searchRequest.source().collapse();
final List<InnerHitBuilder> innerHitBuilders = collapseBuilder.getInnerHits();
@@ -171,6 +182,6 @@ private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilde
}

private void onPhaseDone() {
context.executeNextPhase(NAME, nextPhase);
context.executeNextPhase(NAME, this::nextPhase);
}
}
server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java
@@ -51,9 +51,7 @@ final class FetchLookupFieldsPhase extends SearchPhase {
this.queryResults = queryResults;
}

private record Cluster(String clusterAlias, List<SearchHit> hitsWithLookupFields, List<LookupField> lookupFields) {

}
private record Cluster(String clusterAlias, List<SearchHit> hitsWithLookupFields, List<LookupField> lookupFields) {}

private static List<Cluster> groupLookupFieldsByClusterAlias(SearchHits searchHits) {
final Map<String, List<SearchHit>> perClusters = new HashMap<>();
@@ -80,7 +78,7 @@ private static List<Cluster> groupLookupFieldsByClusterAlias(SearchHi
protected void run() {
final List<Cluster> clusters = groupLookupFieldsByClusterAlias(searchResponse.hits);
if (clusters.isEmpty()) {
context.sendSearchResponse(searchResponse, queryResults);
sendResponse();
return;
}
doRun(clusters);
@@ -132,9 +130,9 @@ public void onResponse(MultiSearchResponse items) {
}
}
if (failure != null) {
context.onPhaseFailure(NAME, "failed to fetch lookup fields", failure);
onFailure(failure);
} else {
context.sendSearchResponse(searchResponse, queryResults);
sendResponse();
}
}

Expand All @@ -144,4 +142,8 @@ public void onFailure(Exception e) {
}
});
}

private void sendResponse() {
context.sendSearchResponse(searchResponse, queryResults);
}
}