2424import org .elasticsearch .common .lucene .Lucene ;
2525
2626import java .io .IOException ;
27+ import java .util .ArrayList ;
28+ import java .util .Collection ;
29+ import java .util .List ;
2730import java .util .Objects ;
31+ import java .util .concurrent .atomic .AtomicInteger ;
2832
2933/**
3034 * Top-level collector used in the query phase to perform top hits collection as well as aggs collection.
4044final class QueryPhaseCollector implements Collector {
4145 private final Collector aggsCollector ;
4246 private final Collector topDocsCollector ;
43- private final int terminateAfter ;
47+ private final TerminateAfterChecker terminateAfterChecker ;
4448 private final Weight postFilterWeight ;
4549 private final Float minScore ;
4650 private final boolean cacheScores ;
47-
48- private int numCollected ;
4951 private boolean terminatedAfter = false ;
5052
5153 QueryPhaseCollector (Collector topDocsCollector , Weight postFilterWeight , int terminateAfter , Collector aggsCollector , Float minScore ) {
54+ this (topDocsCollector , postFilterWeight , resolveTerminateAfterChecker (terminateAfter ), aggsCollector , minScore );
55+ }
56+
57+ QueryPhaseCollector (
58+ Collector topDocsCollector ,
59+ Weight postFilterWeight ,
60+ TerminateAfterChecker terminateAfterChecker ,
61+ Collector aggsCollector ,
62+ Float minScore
63+ ) {
5264 this .topDocsCollector = Objects .requireNonNull (topDocsCollector );
5365 this .postFilterWeight = postFilterWeight ;
54- if (terminateAfter < 0 ) {
55- throw new IllegalArgumentException ("terminateAfter must be greater than or equal to 0" );
56- }
57- this .terminateAfter = terminateAfter ;
66+ this .terminateAfterChecker = terminateAfterChecker ;
5867 this .aggsCollector = aggsCollector ;
5968 this .minScore = minScore ;
6069 this .cacheScores = aggsCollector != null && topDocsCollector .scoreMode ().needsScores () && aggsCollector .scoreMode ().needsScores ();
@@ -104,30 +113,16 @@ boolean isTerminatedAfter() {
104113 }
105114
106115 private boolean shouldCollectTopDocs (int doc , Scorable scorer , Bits postFilterBits ) throws IOException {
107- if (isDocWithinMinScore (scorer )) {
108- if (doesDocMatchPostFilter (doc , postFilterBits )) {
109- // terminate_after is purposely applied after post_filter, and terminates aggs collection based on number of filtered
110- // top hits that have been collected. Strange feature, but that has been behaviour for a long time.
111- applyTerminateAfter ();
112- return true ;
113- }
114- }
115- return false ;
116+ return isDocWithinMinScore (scorer ) && (postFilterBits == null || postFilterBits .get (doc ));
116117 }
117118
118119 private boolean isDocWithinMinScore (Scorable scorer ) throws IOException {
119120 return minScore == null || scorer .score () >= minScore ;
120121 }
121122
122- private static boolean doesDocMatchPostFilter (int doc , Bits postFilterBits ) {
123- return postFilterBits == null || postFilterBits .get (doc );
124- }
125-
126- private void applyTerminateAfter () {
127- if (terminateAfter > 0 && numCollected >= terminateAfter ) {
128- terminatedAfter = true ;
129- throw new CollectionTerminatedException ();
130- }
123+ private void earlyTerminate () {
124+ terminatedAfter = true ;
125+ throw new CollectionTerminatedException ();
131126 }
132127
133128 private Bits getPostFilterBits (LeafReaderContext context ) throws IOException {
@@ -140,12 +135,14 @@ private Bits getPostFilterBits(LeafReaderContext context) throws IOException {
140135
141136 @ Override
142137 public LeafCollector getLeafCollector (LeafReaderContext context ) throws IOException {
143- applyTerminateAfter ();
138+ if (terminateAfterChecker .isThresholdReached ()) {
139+ earlyTerminate ();
140+ }
144141 Bits postFilterBits = getPostFilterBits (context );
145142
146143 if (aggsCollector == null ) {
147144 final LeafCollector topDocsLeafCollector = topDocsCollector .getLeafCollector (context );
148- if (postFilterBits == null && terminateAfter == 0 && minScore == null ) {
145+ if (postFilterBits == null && terminateAfterChecker == NO_OP_TERMINATE_AFTER_CHECKER && minScore == null ) {
149146 // no need to wrap if we just need to collect unfiltered docs through leaf collector.
150147 // aggs collector was not originally provided so the overall score mode is that of the top docs collector
151148 return topDocsLeafCollector ;
@@ -182,7 +179,10 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept
182179 // if that the aggs collector early terminates while the top docs collector does not, we still need to wrap the leaf collector
183180 // to enforce that setMinCompetitiveScore is a no-op. Otherwise we may allow the top docs collector to skip non competitive
184181 // hits despite the score mode of the Collector did not allow it (because aggs don't support TOP_SCORES).
185- if (aggsLeafCollector == null && postFilterBits == null && terminateAfter == 0 && minScore == null ) {
182+ if (aggsLeafCollector == null
183+ && postFilterBits == null
184+ && terminateAfterChecker == NO_OP_TERMINATE_AFTER_CHECKER
185+ && minScore == null ) {
186186 // special case for early terminated aggs
187187 return new FilterLeafCollector (topDocsLeafCollector ) {
188188 @ Override
@@ -213,7 +213,7 @@ private class TopDocsLeafCollector implements LeafCollector {
213213
214214 TopDocsLeafCollector (Bits postFilterBits , LeafCollector topDocsLeafCollector ) {
215215 assert topDocsLeafCollector != null ;
216- assert postFilterBits != null || terminateAfter > 0 || minScore != null ;
216+ assert postFilterBits != null || terminateAfterChecker != NO_OP_TERMINATE_AFTER_CHECKER || minScore != null ;
217217 this .postFilterBits = postFilterBits ;
218218 this .topDocsLeafCollector = topDocsLeafCollector ;
219219 }
@@ -232,7 +232,11 @@ public DocIdSetIterator competitiveIterator() throws IOException {
232232 @ Override
233233 public void collect (int doc ) throws IOException {
234234 if (shouldCollectTopDocs (doc , scorer , postFilterBits )) {
235- numCollected ++;
235+ // terminate_after is purposely applied after post_filter, and terminates aggs collection based on number of filtered
236+ // top hits that have been collected. Strange feature, but that has been behaviour for a long time.
237+ if (terminateAfterChecker .incrementHitCountAndCheckThreshold ()) {
238+ earlyTerminate ();
239+ }
236240 topDocsLeafCollector .collect (doc );
237241 }
238242 }
@@ -278,7 +282,9 @@ public void collect(int doc) throws IOException {
278282 if (shouldCollectTopDocs (doc , scorer , postFilterBits )) {
279283 // we keep on counting and checking the terminate_after threshold so that we can terminate aggs collection
280284 // even if top docs collection early terminated
281- numCollected ++;
285+ if (terminateAfterChecker .incrementHitCountAndCheckThreshold ()) {
286+ earlyTerminate ();
287+ }
282288 if (topDocsLeafCollector != null ) {
283289 try {
284290 topDocsLeafCollector .collect (doc );
@@ -320,4 +326,135 @@ public DocIdSetIterator competitiveIterator() throws IOException {
320326 return null ;
321327 }
322328 }
329+
330+ static CollectorManager createManager (
331+ org .apache .lucene .search .CollectorManager <? extends Collector , Void > topDocsCollectorManager ,
332+ Weight postFilterWeight ,
333+ int terminateAfter ,
334+ org .apache .lucene .search .CollectorManager <? extends Collector , Void > aggsCollectorManager ,
335+ Float minScore
336+ ) {
337+ return new CollectorManager (
338+ topDocsCollectorManager ,
339+ postFilterWeight ,
340+ resolveTerminateAfterChecker (terminateAfter ),
341+ aggsCollectorManager ,
342+ minScore
343+ );
344+ }
345+
346+ private static TerminateAfterChecker resolveTerminateAfterChecker (int terminateAfter ) {
347+ if (terminateAfter < 0 ) {
348+ throw new IllegalArgumentException ("terminateAfter must be greater than or equal to 0" );
349+ }
350+ return terminateAfter == 0 ? NO_OP_TERMINATE_AFTER_CHECKER : new GlobalTerminateAfterChecker (terminateAfter );
351+ }
352+
353+ private abstract static class TerminateAfterChecker {
354+ abstract boolean isThresholdReached ();
355+
356+ abstract boolean incrementHitCountAndCheckThreshold ();
357+ }
358+
359+ private static final class GlobalTerminateAfterChecker extends TerminateAfterChecker {
360+ private final int terminateAfter ;
361+ private final AtomicInteger numCollected = new AtomicInteger ();
362+
363+ GlobalTerminateAfterChecker (int terminateAfter ) {
364+ assert terminateAfter > 0 ;
365+ this .terminateAfter = terminateAfter ;
366+ }
367+
368+ boolean isThresholdReached () {
369+ return numCollected .getAcquire () >= terminateAfter ;
370+ }
371+
372+ boolean incrementHitCountAndCheckThreshold () {
373+ return numCollected .incrementAndGet () > terminateAfter ;
374+ }
375+ }
376+
377+ // no needless counting when terminate_after is not set
378+ private static final TerminateAfterChecker NO_OP_TERMINATE_AFTER_CHECKER = new TerminateAfterChecker () {
379+ @ Override
380+ boolean isThresholdReached () {
381+ return false ;
382+ }
383+
384+ @ Override
385+ boolean incrementHitCountAndCheckThreshold () {
386+ return false ;
387+ }
388+ };
389+
390+ /**
391+ * {@link org.apache.lucene.search.CollectorManager} implementation based on {@link QueryPhaseCollector}.
392+ * Wraps two {@link org.apache.lucene.search.CollectorManager}s: one required for top docs collection, and another one optional for
393+ * aggs collection. Applies terminate_after consistently across the different collectors by sharing an atomic counter of collected docs.
394+ */
395+ static class CollectorManager implements org .apache .lucene .search .CollectorManager <QueryPhaseCollector , Void > {
396+ private final Weight postFilterWeight ;
397+ private final TerminateAfterChecker terminateAfterChecker ;
398+ private final Float minScore ;
399+ private final org .apache .lucene .search .CollectorManager <? extends Collector , Void > topDocsCollectorManager ;
400+ private final org .apache .lucene .search .CollectorManager <? extends Collector , Void > aggsCollectorManager ;
401+
402+ private boolean terminatedAfter ;
403+
404+ CollectorManager (
405+ org .apache .lucene .search .CollectorManager <? extends Collector , Void > topDocsCollectorManager ,
406+ Weight postFilterWeight ,
407+ TerminateAfterChecker terminateAfterChecker ,
408+ org .apache .lucene .search .CollectorManager <? extends Collector , Void > aggsCollectorManager ,
409+ Float minScore
410+ ) {
411+ this .topDocsCollectorManager = topDocsCollectorManager ;
412+ this .postFilterWeight = postFilterWeight ;
413+ this .terminateAfterChecker = terminateAfterChecker ;
414+ this .aggsCollectorManager = aggsCollectorManager ;
415+ this .minScore = minScore ;
416+ }
417+
418+ @ Override
419+ public QueryPhaseCollector newCollector () throws IOException {
420+ Collector aggsCollector = aggsCollectorManager == null ? null : aggsCollectorManager .newCollector ();
421+ return new QueryPhaseCollector (
422+ topDocsCollectorManager .newCollector (),
423+ postFilterWeight ,
424+ terminateAfterChecker ,
425+ aggsCollector ,
426+ minScore
427+ );
428+ }
429+
430+ @ Override
431+ public Void reduce (Collection <QueryPhaseCollector > collectors ) throws IOException {
432+ List <Collector > topDocsCollectors = new ArrayList <>();
433+ List <Collector > aggsCollectors = new ArrayList <>();
434+ for (QueryPhaseCollector collector : collectors ) {
435+ topDocsCollectors .add (collector .topDocsCollector );
436+ aggsCollectors .add (collector .aggsCollector );
437+ if (collector .isTerminatedAfter ()) {
438+ terminatedAfter = true ;
439+ }
440+ }
441+ @ SuppressWarnings ("unchecked" )
442+ org .apache .lucene .search .CollectorManager <Collector , Void > topDocsManager = (org .apache .lucene .search .CollectorManager <
443+ Collector ,
444+ Void >) topDocsCollectorManager ;
445+ topDocsManager .reduce (topDocsCollectors );
446+ if (aggsCollectorManager != null ) {
447+ @ SuppressWarnings ("unchecked" )
448+ org .apache .lucene .search .CollectorManager <Collector , Void > aggsManager = (org .apache .lucene .search .CollectorManager <
449+ Collector ,
450+ Void >) aggsCollectorManager ;
451+ aggsManager .reduce (aggsCollectors );
452+ }
453+ return null ;
454+ }
455+
456+ boolean isTerminatedAfter () {
457+ return terminatedAfter ;
458+ }
459+ }
323460}
0 commit comments