diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 5ef7664501b39..34993c7dca4c7 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -218,7 +218,7 @@ public final void start() { } @Override - public final void run() { + protected final void run() { for (final SearchShardIterator iterator : toSkipShardsIts) { assert iterator.skip(); skipShard(iterator); @@ -300,7 +300,7 @@ private static boolean assertExecuteOnStartThread() { return true; } - protected void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) { + private void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final SearchShardTarget shard) { if (throttleConcurrentRequests) { var pendingExecutions = pendingExecutionsPerNode.computeIfAbsent( shard.getNodeId(), @@ -363,7 +363,7 @@ protected abstract void executePhaseOnShard( * of the next phase. If there are no successful operations in the context when this method is executed the search is aborted and * a response is returned to the user indicating that all shards have failed. */ - protected void executeNextPhase(SearchPhase currentPhase, Supplier nextPhaseSupplier) { + protected void executeNextPhase(String currentPhase, Supplier nextPhaseSupplier) { /* This is the main search phase transition where we move to the next phase. If all shards * failed or if there was a failure and partial results are not allowed, then we immediately * fail. Otherwise we continue to the next phase. @@ -374,7 +374,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier Throwable cause = shardSearchFailures.length == 0 ? null : ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0]; - logger.debug(() -> "All shards failed for phase: [" + currentPhase.getName() + "]", cause); + logger.debug(() -> "All shards failed for phase: [" + currentPhase + "]", cause); onPhaseFailure(currentPhase, "all shards failed", cause); } else { Boolean allowPartialResults = request.allowPartialSearchResults(); @@ -387,7 +387,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier int numShardFailures = shardSearchFailures.length; shardSearchFailures = ExceptionsHelper.groupBy(shardSearchFailures); Throwable cause = ElasticsearchException.guessRootCauses(shardSearchFailures[0].getCause())[0]; - logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase.getName()), cause); + logger.debug(() -> format("%s shards failed for phase: [%s]", numShardFailures, currentPhase), cause); } onPhaseFailure(currentPhase, "Partial shards failure", null); } else { @@ -400,7 +400,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier successfulOps.get(), toSkipShardsIts.size(), getNumShards(), - currentPhase.getName() + currentPhase ); } onPhaseFailure(currentPhase, "Partial shards failure (" + discrepancy + " shards unavailable)", null); @@ -414,7 +414,7 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier .collect(Collectors.joining(",")); logger.trace( "[{}] Moving to next phase: [{}], based on results from: {} (cluster state version: {})", - currentPhase.getName(), + currentPhase, nextPhase.getName(), resultsFrom, clusterStateVersion @@ -427,11 +427,11 @@ protected void executeNextPhase(SearchPhase currentPhase, Supplier private void executePhase(SearchPhase phase) { try { phase.run(); - } catch (Exception e) { + } catch (RuntimeException e) { if (logger.isDebugEnabled()) { logger.debug(() -> format("Failed to execute [%s] while moving to [%s] phase", request, phase.getName()), e); } - onPhaseFailure(phase, "", e); + onPhaseFailure(phase.getName(), "", e); } } @@ -686,8 +686,8 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At * @param msg an optional message * @param cause the cause of the phase failure */ - public void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) { - raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures())); + public void onPhaseFailure(String phase, String msg, Throwable cause) { + raisePhaseFailure(new SearchPhaseExecutionException(phase, msg, cause, buildShardFailures())); } /** @@ -732,7 +732,7 @@ void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connecti * @see #onShardResult(SearchPhaseResult, SearchShardIterator) */ private void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() - executeNextPhase(this, this::getNextPhase); + executeNextPhase(getName(), this::getNextPhase); } /** diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index cc8c4becea9a9..d67e656773495 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -8,8 +8,15 @@ */ package org.elasticsearch.action.search; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TermStatistics; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -27,9 +34,11 @@ import org.elasticsearch.transport.Transport; import java.util.ArrayList; +import java.util.Collection; import java.util.Comparator; +import java.util.HashMap; import java.util.List; -import java.util.function.Function; +import java.util.Map; /** * This search phase fans out to every shards to execute a distributed search with a pre-collected distributed frequencies for all @@ -38,53 +47,50 @@ * operation. * @see CountedCollector#onFailure(int, SearchShardTarget, Exception) */ -final class DfsQueryPhase extends SearchPhase { +class DfsQueryPhase extends SearchPhase { + + public static final String NAME = "dfs_query"; + private final SearchPhaseResults queryResult; - private final List searchResults; - private final AggregatedDfs dfs; - private final List knnResults; - private final Function, SearchPhase> nextPhaseFactory; + private final Client client; private final AbstractSearchAsyncAction context; - private final SearchTransportService searchTransportService; private final SearchProgressListener progressListener; - DfsQueryPhase( - List searchResults, - AggregatedDfs dfs, - List knnResults, - SearchPhaseResults queryResult, - Function, SearchPhase> nextPhaseFactory, - AbstractSearchAsyncAction context - ) { - super("dfs_query"); + DfsQueryPhase(SearchPhaseResults queryResult, Client client, AbstractSearchAsyncAction context) { + super(NAME); this.progressListener = context.getTask().getProgressListener(); this.queryResult = queryResult; - this.searchResults = searchResults; - this.dfs = dfs; - this.knnResults = knnResults; - this.nextPhaseFactory = nextPhaseFactory; + this.client = client; this.context = context; - this.searchTransportService = context.getSearchTransport(); } + // protected for testing + protected SearchPhase nextPhase(AggregatedDfs dfs) { + return SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResult, dfs); + } + + @SuppressWarnings("unchecked") @Override - public void run() { + protected void run() { + List searchResults = (List) context.results.getAtomicArray().asList(); + AggregatedDfs dfs = aggregateDfs(searchResults); // TODO we can potentially also consume the actual per shard results from the initial phase here in the aggregateDfs // to free up memory early final CountedCollector counter = new CountedCollector<>( queryResult, searchResults.size(), - () -> context.executeNextPhase(this, () -> nextPhaseFactory.apply(queryResult)), + () -> context.executeNextPhase(NAME, () -> nextPhase(dfs)), context ); + List knnResults = mergeKnnResults(context.getRequest(), searchResults); for (final DfsSearchResult dfsResult : searchResults) { final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget(); final int shardIndex = dfsResult.getShardIndex(); QuerySearchRequest querySearchRequest = new QuerySearchRequest( context.getOriginalIndices(shardIndex), dfsResult.getContextId(), - rewriteShardSearchRequest(dfsResult.getShardSearchRequest()), + rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()), dfs ); final Transport.Connection connection; @@ -94,11 +100,8 @@ public void run() { shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter); continue; } - searchTransportService.sendExecuteQuery( - connection, - querySearchRequest, - context.getTask(), - new SearchActionListener<>(shardTarget, shardIndex) { + context.getSearchTransport() + .sendExecuteQuery(connection, querySearchRequest, context.getTask(), new SearchActionListener<>(shardTarget, shardIndex) { @Override protected void innerOnResponse(QuerySearchResult response) { @@ -106,7 +109,7 @@ protected void innerOnResponse(QuerySearchResult response) { response.setSearchProfileDfsPhaseResult(dfsResult.searchProfileDfsPhaseResult()); counter.onResult(response); } catch (Exception e) { - context.onPhaseFailure(DfsQueryPhase.this, "", e); + context.onPhaseFailure(NAME, "", e); } } @@ -123,8 +126,7 @@ public void onFailure(Exception exception) { } } } - } - ); + }); } } @@ -141,7 +143,7 @@ private void shardFailure( } // package private for testing - ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { + ShardSearchRequest rewriteShardSearchRequest(List knnResults, ShardSearchRequest request) { SearchSourceBuilder source = request.source(); if (source == null || source.knnSearch().isEmpty()) { return request; @@ -177,4 +179,95 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { return request; } + + private static List mergeKnnResults(SearchRequest request, List dfsSearchResults) { + if (request.hasKnnSearch() == false) { + return null; + } + SearchSourceBuilder source = request.source(); + List> topDocsLists = new ArrayList<>(source.knnSearch().size()); + List> nestedPath = new ArrayList<>(source.knnSearch().size()); + for (int i = 0; i < source.knnSearch().size(); i++) { + topDocsLists.add(new ArrayList<>()); + nestedPath.add(new SetOnce<>()); + } + + for (DfsSearchResult dfsSearchResult : dfsSearchResults) { + if (dfsSearchResult.knnResults() != null) { + for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) { + DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i); + ScoreDoc[] scoreDocs = knnResults.scoreDocs(); + TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO); + TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs); + SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex()); + topDocsLists.get(i).add(shardTopDocs); + nestedPath.get(i).trySet(knnResults.getNestedPath()); + } + } + } + + List mergedResults = new ArrayList<>(source.knnSearch().size()); + for (int i = 0; i < source.knnSearch().size(); i++) { + TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0])); + mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs)); + } + return mergedResults; + } + + private static AggregatedDfs aggregateDfs(Collection results) { + Map termStatistics = new HashMap<>(); + Map fieldStatistics = new HashMap<>(); + long aggMaxDoc = 0; + for (DfsSearchResult lEntry : results) { + final Term[] terms = lEntry.terms(); + final TermStatistics[] stats = lEntry.termStatistics(); + assert terms.length == stats.length; + for (int i = 0; i < terms.length; i++) { + assert terms[i] != null; + if (stats[i] == null) { + continue; + } + TermStatistics existing = termStatistics.get(terms[i]); + if (existing != null) { + assert terms[i].bytes().equals(existing.term()); + termStatistics.put( + terms[i], + new TermStatistics( + existing.term(), + existing.docFreq() + stats[i].docFreq(), + existing.totalTermFreq() + stats[i].totalTermFreq() + ) + ); + } else { + termStatistics.put(terms[i], stats[i]); + } + + } + + assert lEntry.fieldStatistics().containsKey(null) == false; + for (var entry : lEntry.fieldStatistics().entrySet()) { + String key = entry.getKey(); + CollectionStatistics value = entry.getValue(); + if (value == null) { + continue; + } + assert key != null; + CollectionStatistics existing = fieldStatistics.get(key); + if (existing != null) { + CollectionStatistics merged = new CollectionStatistics( + key, + existing.maxDoc() + value.maxDoc(), + existing.docCount() + value.docCount(), + existing.sumTotalTermFreq() + value.sumTotalTermFreq(), + existing.sumDocFreq() + value.sumDocFreq() + ); + fieldStatistics.put(key, merged); + } else { + fieldStatistics.put(key, value); + } + } + aggMaxDoc += lEntry.maxDoc(); + } + return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java index e8d94c32bdcc7..8055ebb1a7358 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java @@ -12,34 +12,47 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.Maps; +import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.InnerHitBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import java.util.Iterator; import java.util.List; -import java.util.function.Supplier; /** * This search phase is an optional phase that will be executed once all hits are fetched from the shards that executes * field-collapsing on the inner hits. This phase only executes if field collapsing is requested in the search request and otherwise * forwards to the next phase immediately. */ -final class ExpandSearchPhase extends SearchPhase { +class ExpandSearchPhase extends SearchPhase { + + static final String NAME = "expand"; + private final AbstractSearchAsyncAction context; - private final SearchHits searchHits; - private final Supplier nextPhase; + private final SearchResponseSections searchResponseSections; + private final AtomicArray queryPhaseResults; - ExpandSearchPhase(AbstractSearchAsyncAction context, SearchHits searchHits, Supplier nextPhase) { - super("expand"); + ExpandSearchPhase( + AbstractSearchAsyncAction context, + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + super(NAME); this.context = context; - this.searchHits = searchHits; - this.nextPhase = nextPhase; + this.searchResponseSections = searchResponseSections; + this.queryPhaseResults = queryPhaseResults; + } + + // protected for tests + protected SearchPhase nextPhase() { + return new FetchLookupFieldsPhase(context, searchResponseSections, queryPhaseResults); } /** @@ -51,15 +64,16 @@ private boolean isCollapseRequest() { } @Override - public void run() { + protected void run() { + var searchHits = searchResponseSections.hits(); if (isCollapseRequest() == false || searchHits.getHits().length == 0) { onPhaseDone(); } else { - doRun(); + doRun(searchHits); } } - private void doRun() { + private void doRun(SearchHits searchHits) { SearchRequest searchRequest = context.getRequest(); CollapseBuilder collapseBuilder = searchRequest.source().collapse(); final List innerHitBuilders = collapseBuilder.getInnerHits(); @@ -123,7 +137,7 @@ private void doRun() { } private void phaseFailure(Exception ex) { - context.onPhaseFailure(this, "failed to expand hits", ex); + context.onPhaseFailure(NAME, "failed to expand hits", ex); } private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilder options, CollapseBuilder innerCollapseBuilder) { @@ -168,6 +182,6 @@ private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilde } private void onPhaseDone() { - context.executeNextPhase(this, nextPhase); + context.executeNextPhase(NAME, this::nextPhase); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java index d8671bcadf86d..9aba4efa03bf4 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchLookupFieldsPhase.java @@ -33,6 +33,9 @@ * @see org.elasticsearch.index.mapper.LookupRuntimeFieldType */ final class FetchLookupFieldsPhase extends SearchPhase { + + static final String NAME = "fetch_lookup_fields"; + private final AbstractSearchAsyncAction context; private final SearchResponseSections searchResponse; private final AtomicArray queryResults; @@ -42,15 +45,13 @@ final class FetchLookupFieldsPhase extends SearchPhase { SearchResponseSections searchResponse, AtomicArray queryResults ) { - super("fetch_lookup_fields"); + super(NAME); this.context = context; this.searchResponse = searchResponse; this.queryResults = queryResults; } - private record Cluster(String clusterAlias, List hitsWithLookupFields, List lookupFields) { - - } + private record Cluster(String clusterAlias, List hitsWithLookupFields, List lookupFields) {} private static List groupLookupFieldsByClusterAlias(SearchHits searchHits) { final Map> perClusters = new HashMap<>(); @@ -74,10 +75,10 @@ private static List groupLookupFieldsByClusterAlias(SearchHits searchHi } @Override - public void run() { + protected void run() { final List clusters = groupLookupFieldsByClusterAlias(searchResponse.hits); if (clusters.isEmpty()) { - context.sendSearchResponse(searchResponse, queryResults); + sendResponse(); return; } doRun(clusters); @@ -129,16 +130,20 @@ public void onResponse(MultiSearchResponse items) { } } if (failure != null) { - context.onPhaseFailure(FetchLookupFieldsPhase.this, "failed to fetch lookup fields", failure); + onFailure(failure); } else { - context.sendSearchResponse(searchResponse, queryResults); + sendResponse(); } } @Override public void onFailure(Exception e) { - context.onPhaseFailure(FetchLookupFieldsPhase.this, "failed to fetch lookup fields", e); + context.onPhaseFailure(NAME, "failed to fetch lookup fields", e); } }); } + + private void sendResponse() { + context.sendSearchResponse(searchResponse, queryResults); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 31125e6d8b3c5..e63a3ef5b979f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -27,15 +27,16 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; /** * This search phase merges the query results from the previous phase together and calculates the topN hits for this search. * Then it reaches out to all relevant shards to fetch the topN hits. */ -final class FetchSearchPhase extends SearchPhase { + +class FetchSearchPhase extends SearchPhase { + static final String NAME = "fetch"; + private final AtomicArray searchPhaseShardResults; - private final BiFunction, SearchPhase> nextPhaseFactory; private final AbstractSearchAsyncAction context; private final Logger logger; private final SearchProgressListener progressListener; @@ -50,27 +51,7 @@ final class FetchSearchPhase extends SearchPhase { AbstractSearchAsyncAction context, @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { - this( - resultConsumer, - aggregatedDfs, - context, - reducedQueryPhase, - (response, queryPhaseResults) -> new ExpandSearchPhase( - context, - response.hits, - () -> new FetchLookupFieldsPhase(context, response, queryPhaseResults) - ) - ); - } - - FetchSearchPhase( - SearchPhaseResults resultConsumer, - AggregatedDfs aggregatedDfs, - AbstractSearchAsyncAction context, - @Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase, - BiFunction, SearchPhase> nextPhaseFactory - ) { - super("fetch"); + super(NAME); if (context.getNumShards() != resultConsumer.getNumShards()) { throw new IllegalStateException( "number of shards must match the length of the query results but doesn't:" @@ -81,7 +62,6 @@ final class FetchSearchPhase extends SearchPhase { } this.searchPhaseShardResults = resultConsumer.getAtomicArray(); this.aggregatedDfs = aggregatedDfs; - this.nextPhaseFactory = nextPhaseFactory; this.context = context; this.logger = context.getLogger(); this.progressListener = context.getTask().getProgressListener(); @@ -89,8 +69,13 @@ final class FetchSearchPhase extends SearchPhase { this.resultConsumer = reducedQueryPhase == null ? resultConsumer : null; } + // protected for tests + protected SearchPhase nextPhase(SearchResponseSections searchResponseSections, AtomicArray queryPhaseResults) { + return new ExpandSearchPhase(context, searchResponseSections, queryPhaseResults); + } + @Override - public void run() { + protected void run() { context.execute(new AbstractRunnable() { @Override @@ -100,7 +85,7 @@ protected void doRun() throws Exception { @Override public void onFailure(Exception e) { - context.onPhaseFailure(FetchSearchPhase.this, "", e); + context.onPhaseFailure(NAME, "", e); } }); } @@ -112,7 +97,7 @@ private void innerRun() throws Exception { final int numShards = context.getNumShards(); // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. - final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1 + final boolean queryAndFetchOptimization = numShards == 1 && context.getRequest().hasKnnSearch() == false && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null && (context.getRequest().source() == null || context.getRequest().source().rankBuilder() == null); @@ -127,7 +112,7 @@ private void innerRun() throws Exception { // we have to release contexts here to free up resources searchPhaseShardResults.asList() .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context)); - moveToNextPhase(new AtomicArray<>(numShards), reducedQueryPhase); + moveToNextPhase(new AtomicArray<>(0), reducedQueryPhase); } else { innerRunFetch(scoreDocs, numShards, reducedQueryPhase); } @@ -228,7 +213,7 @@ public void innerOnResponse(FetchSearchResult result) { progressListener.notifyFetchResult(shardIndex); counter.onResult(result); } catch (Exception e) { - context.onPhaseFailure(FetchSearchPhase.this, "", e); + context.onPhaseFailure(NAME, "", e); } } @@ -275,10 +260,10 @@ private void moveToNextPhase( AtomicArray fetchResultsArr, SearchPhaseController.ReducedQueryPhase reducedQueryPhase ) { - context.executeNextPhase(this, () -> { + context.executeNextPhase(NAME, () -> { var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr); context.addReleasable(resp); - return nextPhaseFactory.apply(resp, searchPhaseShardResults); + return nextPhase(resp, searchPhaseShardResults); }); } diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 199228c9f992c..e9302883457e1 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -37,6 +37,8 @@ */ public class RankFeaturePhase extends SearchPhase { + static final String NAME = "rank-feature"; + private static final Logger logger = LogManager.getLogger(RankFeaturePhase.class); private final AbstractSearchAsyncAction context; final SearchPhaseResults queryPhaseResults; @@ -51,7 +53,7 @@ public class RankFeaturePhase extends SearchPhase { AbstractSearchAsyncAction context, RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext ) { - super("rank-feature"); + super(NAME); assert rankFeaturePhaseRankCoordinatorContext != null; this.rankFeaturePhaseRankCoordinatorContext = rankFeaturePhaseRankCoordinatorContext; if (context.getNumShards() != queryPhaseResults.getNumShards()) { @@ -71,7 +73,7 @@ public class RankFeaturePhase extends SearchPhase { } @Override - public void run() { + protected void run() { context.execute(new AbstractRunnable() { @Override protected void doRun() throws Exception { @@ -84,7 +86,7 @@ protected void doRun() throws Exception { @Override public void onFailure(Exception e) { - context.onPhaseFailure(RankFeaturePhase.this, "", e); + context.onPhaseFailure(NAME, "", e); } }); } @@ -139,7 +141,7 @@ protected void innerOnResponse(RankFeatureResult response) { progressListener.notifyRankFeatureResult(shardIndex); rankRequestCounter.onResult(response); } catch (Exception e) { - context.onPhaseFailure(RankFeaturePhase.this, "", e); + context.onPhaseFailure(NAME, "", e); } } @@ -194,7 +196,7 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) { @Override public void onFailure(Exception e) { - context.onPhaseFailure(RankFeaturePhase.this, "Computing updated ranks for results failed", e); + context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e); } } ); @@ -239,6 +241,6 @@ private float maxScore(ScoreDoc[] scoreDocs) { } void moveToNextPhase(SearchPhaseResults phaseResults, SearchPhaseController.ReducedQueryPhase reducedQueryPhase) { - context.executeNextPhase(this, () -> new FetchSearchPhase(phaseResults, aggregatedDfs, context, reducedQueryPhase)); + context.executeNextPhase(NAME, () -> new FetchSearchPhase(phaseResults, aggregatedDfs, context, reducedQueryPhase)); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 056806fbb0b00..dd97f02dd8f40 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -10,29 +10,16 @@ package org.elasticsearch.action.search; import org.apache.logging.log4j.Logger; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.CollectionStatistics; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TermStatistics; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; -import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.dfs.AggregatedDfs; -import org.elasticsearch.search.dfs.DfsKnnResults; import org.elasticsearch.search.dfs.DfsSearchResult; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.transport.Transport; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; @@ -102,111 +89,11 @@ protected void executePhaseOnShard( @Override protected SearchPhase getNextPhase() { - final List dfsSearchResults = results.getAtomicArray().asList(); - final AggregatedDfs aggregatedDfs = aggregateDfs(dfsSearchResults); - return new DfsQueryPhase( - dfsSearchResults, - aggregatedDfs, - mergeKnnResults(getRequest(), dfsSearchResults), - queryPhaseResultConsumer, - (queryResults) -> SearchQueryThenFetchAsyncAction.nextPhase(client, this, queryResults, aggregatedDfs), - this - ); + return new DfsQueryPhase(queryPhaseResultConsumer, client, this); } @Override protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) { progressListener.notifyQueryFailure(shardIndex, shardTarget, exc); } - - private static List mergeKnnResults(SearchRequest request, List dfsSearchResults) { - if (request.hasKnnSearch() == false) { - return null; - } - SearchSourceBuilder source = request.source(); - List> topDocsLists = new ArrayList<>(source.knnSearch().size()); - List> nestedPath = new ArrayList<>(source.knnSearch().size()); - for (int i = 0; i < source.knnSearch().size(); i++) { - topDocsLists.add(new ArrayList<>()); - nestedPath.add(new SetOnce<>()); - } - - for (DfsSearchResult dfsSearchResult : dfsSearchResults) { - if (dfsSearchResult.knnResults() != null) { - for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) { - DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i); - ScoreDoc[] scoreDocs = knnResults.scoreDocs(); - TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO); - TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs); - SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex()); - topDocsLists.get(i).add(shardTopDocs); - nestedPath.get(i).trySet(knnResults.getNestedPath()); - } - } - } - - List mergedResults = new ArrayList<>(source.knnSearch().size()); - for (int i = 0; i < source.knnSearch().size(); i++) { - TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0])); - mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs)); - } - return mergedResults; - } - - private static AggregatedDfs aggregateDfs(Collection results) { - Map termStatistics = new HashMap<>(); - Map fieldStatistics = new HashMap<>(); - long aggMaxDoc = 0; - for (DfsSearchResult lEntry : results) { - final Term[] terms = lEntry.terms(); - final TermStatistics[] stats = lEntry.termStatistics(); - assert terms.length == stats.length; - for (int i = 0; i < terms.length; i++) { - assert terms[i] != null; - if (stats[i] == null) { - continue; - } - TermStatistics existing = termStatistics.get(terms[i]); - if (existing != null) { - assert terms[i].bytes().equals(existing.term()); - termStatistics.put( - terms[i], - new TermStatistics( - existing.term(), - existing.docFreq() + stats[i].docFreq(), - existing.totalTermFreq() + stats[i].totalTermFreq() - ) - ); - } else { - termStatistics.put(terms[i], stats[i]); - } - - } - - assert lEntry.fieldStatistics().containsKey(null) == false; - for (var entry : lEntry.fieldStatistics().entrySet()) { - String key = entry.getKey(); - CollectionStatistics value = entry.getValue(); - if (value == null) { - continue; - } - assert key != null; - CollectionStatistics existing = fieldStatistics.get(key); - if (existing != null) { - CollectionStatistics merged = new CollectionStatistics( - key, - existing.maxDoc() + value.maxDoc(), - existing.docCount() + value.docCount(), - existing.sumTotalTermFreq() + value.sumTotalTermFreq(), - existing.sumDocFreq() + value.sumDocFreq() - ); - fieldStatistics.put(key, merged); - } else { - fieldStatistics.put(key, value); - } - } - aggMaxDoc += lEntry.maxDoc(); - } - return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc); - } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java index e30f1db33043d..b4e915cd655a8 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhase.java @@ -8,26 +8,25 @@ */ package org.elasticsearch.action.search; -import org.elasticsearch.core.CheckedRunnable; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.transport.Transport; -import java.io.IOException; -import java.io.UncheckedIOException; import java.util.List; import java.util.Objects; /** * Base class for all individual search phases like collecting distributed frequencies, fetching documents, querying shards. */ -abstract class SearchPhase implements CheckedRunnable { +abstract class SearchPhase { private final String name; protected SearchPhase(String name) { this.name = Objects.requireNonNull(name, "name must not be null"); } + protected abstract void run(); + /** * Returns the phases name. */ @@ -36,11 +35,7 @@ public String getName() { } public void start() { - try { - run(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } + run(); } private static String makeMissingShardsError(StringBuilder missingShards) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java index ca8d7138c1c81..53da76d96e405 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java @@ -46,7 +46,7 @@ * fan out to nodes and execute the query part of the scroll request. Subclasses can for instance * run separate fetch phases etc. */ -abstract class SearchScrollAsyncAction implements Runnable { +abstract class SearchScrollAsyncAction { protected final Logger logger; protected final ActionListener listener; protected final ParsedScrollId scrollId; @@ -235,7 +235,7 @@ protected SearchPhase sendResponsePhase( ) { return new SearchPhase("fetch") { @Override - public void run() { + protected void run() { sendResponse(queryPhase, fetchResults); } }; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java index c9728a95cc526..29822e596356f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java @@ -65,7 +65,7 @@ protected void executeInitialPhase( protected SearchPhase moveToNextPhase(BiFunction clusterNodeLookup) { return new SearchPhase("fetch") { @Override - public void run() { + protected void run() { final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = reducedScrollQueryPhase(queryResults.asList()); ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); if (scoreDocs.length == 0) { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index e63fbb7973fb5..43f989c5efdb8 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -264,7 +264,7 @@ protected void executePhaseOnShard( protected SearchPhase getNextPhase() { return new SearchPhase(getName()) { @Override - public void run() { + protected void run() { sendSearchResponse(SearchResponseSections.EMPTY_WITH_TOTAL_HITS, results.getAtomicArray()); } }; diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java index cffba76988f7d..b232cd16ba65e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java @@ -91,7 +91,7 @@ public void onFailure(Exception e) { }; try { ParsedScrollId scrollId = parseScrollId(request.scrollId()); - Runnable action = switch (scrollId.getType()) { + var action = switch (scrollId.getType()) { case QUERY_THEN_FETCH_TYPE -> new SearchScrollQueryThenFetchAsyncAction( logger, clusterService, diff --git a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java index cdd35b64661ea..4ec3e972ae61b 100644 --- a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java @@ -204,12 +204,7 @@ public void testOnPhaseFailure() { List> nodeLookups = new ArrayList<>(); ArraySearchPhaseResults phaseResults = phaseResults(requestIds, nodeLookups, 0); AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong()); - action.onPhaseFailure(new SearchPhase("test") { - @Override - public void run() { - - } - }, "message", null); + action.onPhaseFailure("test", "message", null); assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class)); SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException) exception.get(); assertEquals("message", searchPhaseExecutionException.getMessage()); diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index dd648f1dfd65d..d06c09e9585a9 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.dfs.DfsKnnResults; import org.elasticsearch.search.dfs.DfsSearchResult; import org.elasticsearch.search.internal.AliasFilter; @@ -139,12 +140,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -225,12 +221,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -313,12 +304,7 @@ public void sendExecuteQuery( exc -> {} ) ) { - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(((QueryPhaseResultConsumer) response).results); - } - }, mockSearchPhaseContext); + DfsQueryPhase phase = makeDfsPhase(results, consumer, mockSearchPhaseContext, responseRef); assertEquals("dfs_query", phase.getName()); phase.run(); assertThat(mockSearchPhaseContext.failures, hasSize(1)); @@ -328,6 +314,29 @@ public void run() throws IOException { } } + private static DfsQueryPhase makeDfsPhase( + AtomicArray results, + SearchPhaseResults consumer, + MockSearchPhaseContext mockSearchPhaseContext, + AtomicReference> responseRef + ) { + int shards = mockSearchPhaseContext.numShards; + for (int i = 0; i < shards; i++) { + mockSearchPhaseContext.results.getAtomicArray().set(i, results.get(i)); + } + return new DfsQueryPhase(consumer, null, mockSearchPhaseContext) { + @Override + protected SearchPhase nextPhase(AggregatedDfs dfs) { + return new SearchPhase("test") { + @Override + public void run() { + responseRef.set(((QueryPhaseResultConsumer) consumer).results); + } + }; + } + }; + } + public void testRewriteShardSearchRequestWithRank() { List dkrs = List.of( new DfsKnnResults(null, new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1), new ScoreDoc(7, 0.1f, 2) }), @@ -338,7 +347,7 @@ public void testRewriteShardSearchRequestWithRank() { ); MockSearchPhaseContext mspc = new MockSearchPhaseContext(2); mspc.searchTransport = new SearchTransportService(null, null, null); - DfsQueryPhase dqp = new DfsQueryPhase(null, null, dkrs, mock(QueryPhaseResultConsumer.class), null, mspc); + DfsQueryPhase dqp = new DfsQueryPhase(mock(QueryPhaseResultConsumer.class), null, mspc); QueryBuilder bm25 = new TermQueryBuilder("field", "term"); SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) @@ -352,7 +361,7 @@ public void testRewriteShardSearchRequestWithRank() { SearchRequest sr = new SearchRequest().allowPartialSearchResults(true).source(ssb); ShardSearchRequest ssr = new ShardSearchRequest(null, sr, new ShardId("test", "testuuid", 1), 1, 1, null, 1.0f, 0, null); - dqp.rewriteShardSearchRequest(ssr); + dqp.rewriteShardSearchRequest(dkrs, ssr); KnnScoreDocQueryBuilder ksdqb0 = new KnnScoreDocQueryBuilder( new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) }, diff --git a/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java index 5ba5a4f903e83..99b4cd5681557 100644 --- a/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/ExpandSearchPhaseTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.AbstractSearchTestCase; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; @@ -117,14 +118,11 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(collapseValue))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } - } - }); + ExpandSearchPhase phase = newExpandSearchPhase( + mockSearchPhaseContext, + new SearchResponseSections(hits, null, null, false, null, null, 1), + null + ); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -148,7 +146,6 @@ public void run() { if (resp != null) { resp.decRef(); } - } } } @@ -206,15 +203,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit2 = new SearchHit(2, "ID2"); hit2.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(collapseValue))); SearchHits hits = new SearchHits(new SearchHit[] { hit1, hit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } - } - }); + try (SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); assertThat(mockSearchPhaseContext.phaseFailure.get(), Matchers.instanceOf(RuntimeException.class)); assertEquals("boom", mockSearchPhaseContext.phaseFailure.get().getMessage()); @@ -243,14 +233,11 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL hit2.setDocumentField("someField", new DocumentField("someField", Collections.singletonList(null))); SearchHits hits = new SearchHits(new SearchHit[] { hit1, hit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - try (var sections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { - mockSearchPhaseContext.sendSearchResponse(sections, null); - } - } - }); + ExpandSearchPhase phase = newExpandSearchPhase( + mockSearchPhaseContext, + new SearchResponseSections(hits, null, null, false, null, null, 1), + null + ); phase.run(); mockSearchPhaseContext.assertNoFailure(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); @@ -283,12 +270,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL ); SearchHits hits = SearchHits.empty(new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null); - } - }); + final SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1); + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); mockSearchPhaseContext.assertNoFailure(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); @@ -332,13 +315,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit = new SearchHit(1, "ID"); hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList("foo"))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(new SearchResponseSections(hits, null, null, false, null, null, 1), null); - } - }); + try (SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, null); phase.run(); mockSearchPhaseContext.assertNoFailure(); } finally { @@ -353,6 +331,26 @@ public void run() { } } + private static ExpandSearchPhase newExpandSearchPhase( + MockSearchPhaseContext mockSearchPhaseContext, + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return new ExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, queryPhaseResults) { + @Override + protected SearchPhase nextPhase() { + return new SearchPhase("test") { + @Override + public void run() { + try (searchResponseSections) { + mockSearchPhaseContext.sendSearchResponse(searchResponseSections, queryPhaseResults); + } + } + }; + } + }; + } + public void testExpandSearchRespectsOriginalPIT() { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); final PointInTimeBuilder pit = new PointInTimeBuilder(new BytesArray("foo")); @@ -381,16 +379,8 @@ void sendExecuteMultiSearch(MultiSearchRequest request, SearchTask task, ActionL SearchHit hit = new SearchHit(1, "ID"); hit.setDocumentField("someField", new DocumentField("someField", Collections.singletonList("foo"))); SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - try { - ExpandSearchPhase phase = new ExpandSearchPhase(mockSearchPhaseContext, hits, () -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse( - new SearchResponseSections(hits, null, null, false, null, null, 1), - new AtomicArray<>(0) - ); - } - }); + try (SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, null, null, 1)) { + ExpandSearchPhase phase = newExpandSearchPhase(mockSearchPhaseContext, searchResponseSections, new AtomicArray<>(0)); phase.run(); mockSearchPhaseContext.assertNoFailure(); } finally { diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index 7cb8c66ff8ee2..7d3903ad10ea4 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -134,13 +134,7 @@ public void testShortcutQueryAndFetchOptimization() throws Exception { numHits = 0; } SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -263,13 +257,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -373,13 +361,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -481,19 +463,21 @@ public void sendExecuteFetch( }; CountDownLatch latch = new CountDownLatch(1); SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - (searchResponse, scrollId) -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, null); - latch.countDown(); - } + FetchSearchPhase phase = new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponseSections, null); + latch.countDown(); + } + }; } - ); + }; assertEquals("fetch", phase.getName()); phase.run(); latch.await(); @@ -621,13 +605,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); assertNotNull(mockSearchPhaseContext.searchResponse.get()); @@ -641,6 +619,22 @@ public void sendExecuteFetch( } } + private static FetchSearchPhase getFetchSearchPhase( + SearchPhaseResults results, + MockSearchPhaseContext mockSearchPhaseContext, + SearchPhaseController.ReducedQueryPhase reducedQueryPhase + ) { + return new FetchSearchPhase(results, null, mockSearchPhaseContext, reducedQueryPhase) { + @Override + protected SearchPhase nextPhase( + SearchResponseSections searchResponseSections, + AtomicArray queryPhaseResults + ) { + return searchPhaseFactory(mockSearchPhaseContext).apply(searchResponseSections, queryPhaseResults); + } + }; + } + public void testCleanupIrrelevantContexts() throws Exception { // contexts that are not fetched should be cleaned up MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder()); @@ -723,13 +717,7 @@ public void sendExecuteFetch( } }; SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce(); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - reducedQueryPhase, - searchPhaseFactory(mockSearchPhaseContext) - ); + FetchSearchPhase phase = getFetchSearchPhase(results, mockSearchPhaseContext, reducedQueryPhase); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -764,7 +752,7 @@ private static BiFunction ) { return (searchResponse, scrollId) -> new SearchPhase("test") { @Override - public void run() { + protected void run() { mockSearchPhaseContext.sendSearchResponse(searchResponse, null); } }; diff --git a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java index cc922e39b4d35..40a29d9e2b055 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java @@ -110,7 +110,7 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At } @Override - public void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) { + public void onPhaseFailure(String phase, String msg, Throwable cause) { phaseFailure.set(cause); } @@ -132,12 +132,12 @@ public SearchTransportService getSearchTransport() { } @Override - public void executeNextPhase(SearchPhase currentPhase, Supplier nextPhaseSupplier) { + public void executeNextPhase(String currentPhase, Supplier nextPhaseSupplier) { var nextPhase = nextPhaseSupplier.get(); try { nextPhase.run(); } catch (Exception e) { - onPhaseFailure(nextPhase, "phase failed", e); + onPhaseFailure(nextPhase.getName(), "phase failed", e); } } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java index b06a024dd626c..afd3bee4c4ab8 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java @@ -139,7 +139,7 @@ protected void executePhaseOnShard( protected SearchPhase getNextPhase() { return new SearchPhase("test") { @Override - public void run() { + protected void run() { assertTrue(searchPhaseDidRun.compareAndSet(false, true)); latch.countDown(); } @@ -254,7 +254,7 @@ protected void executePhaseOnShard( protected SearchPhase getNextPhase() { return new SearchPhase("test") { @Override - public void run() { + protected void run() { assertTrue(searchPhaseDidRun.compareAndSet(false, true)); latch.countDown(); } @@ -362,7 +362,7 @@ protected void executePhaseOnShard( protected SearchPhase getNextPhase() { return new SearchPhase("test") { @Override - public void run() { + protected void run() { for (int i = 0; i < results.getNumShards(); i++) { TestSearchPhaseResult result = results.getAtomicArray().get(i); assertEquals(result.node.getId(), result.getSearchShardTarget().getNodeId()); @@ -496,7 +496,7 @@ protected void executePhaseOnShard( protected SearchPhase getNextPhase() { return new SearchPhase("test") { @Override - public void run() { + protected void run() { throw new RuntimeException("boom"); } }; @@ -608,7 +608,7 @@ protected void executePhaseOnShard( protected SearchPhase getNextPhase() { return new SearchPhase("test") { @Override - public void run() { + protected void run() { assertTrue(searchPhaseDidRun.compareAndSet(false, true)); latch.countDown(); } @@ -686,7 +686,7 @@ protected void executePhaseOnShard( protected SearchPhase getNextPhase() { return new SearchPhase("test") { @Override - public void run() { + protected void run() { assertTrue(searchPhaseDidRun.compareAndSet(false, true)); latch.countDown(); } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 7bb8b3d4d9133..cb9222dbe36d6 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -221,7 +221,7 @@ public void sendExecuteQuery( protected SearchPhase getNextPhase() { return new SearchPhase("test") { @Override - public void run() { + protected void run() { latch.countDown(); } }; diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java index c32255a4d4ed9..4b08bfa0de01d 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.Transport; -import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -87,7 +86,7 @@ protected SearchPhase moveToNextPhase(BiFunction assertEquals(1, movedCounter.incrementAndGet()); return new SearchPhase("test") { @Override - public void run() throws IOException { + protected void run() { latch.countDown(); } }; @@ -184,7 +183,7 @@ protected SearchPhase moveToNextPhase(BiFunction assertEquals(1, movedCounter.incrementAndGet()); return new SearchPhase("TEST_PHASE") { @Override - public void run() throws IOException { + protected void run() { throw new IllegalArgumentException("BOOM"); } }; @@ -262,7 +261,7 @@ protected SearchPhase moveToNextPhase(BiFunction assertEquals(1, movedCounter.incrementAndGet()); return new SearchPhase("test") { @Override - public void run() throws IOException { + protected void run() { latch.countDown(); } }; @@ -344,7 +343,7 @@ protected SearchPhase moveToNextPhase(BiFunction assertEquals(1, movedCounter.incrementAndGet()); return new SearchPhase("test") { @Override - public void run() throws IOException { + protected void run() { latch.countDown(); } };