Skip to content

Adding MinScore support to Linear Retriever #124182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
a41dac9
MinScore implementation in Linear retriever
mridula-s109 Mar 6, 2025
61dd8df
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 6, 2025
6be41de
Merge remote-tracking branch 'origin/minscore-linear' into minscore-l…
mridula-s109 Mar 12, 2025
a52628c
Resolving PR comments
mridula-s109 Mar 13, 2025
a225c01
Fixed PR comments, added yaml and made changes to the markdown
mridula-s109 Mar 13, 2025
429a620
Merge branch 'main' into minscore-linear
mridula-s109 Mar 13, 2025
95710e6
Update docs/changelog/124182.yaml
mridula-s109 Mar 13, 2025
a32f947
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 13, 2025
a424e77
Resolved on the PR comments
mridula-s109 Mar 20, 2025
38a9b50
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 21, 2025
810d151
Added changes wrt to yaml testing from PR comments
mridula-s109 Mar 25, 2025
88703fd
Worked on kathleen comments first half
mridula-s109 Mar 26, 2025
d9b44e2
Reverted the integration test in line with the main branch
mridula-s109 Mar 26, 2025
46a1b94
Resolved comments in the PR and its in compiling state
mridula-s109 Mar 26, 2025
73a9bad
Unit tests passing
mridula-s109 Mar 26, 2025
6f93d3d
[CI] Auto commit changes from spotless
elasticsearchmachine Mar 26, 2025
1fd4a22
reverted inclusion of pit in retrierver it file
mridula-s109 Mar 26, 2025
f53cd0a
Removed transport versions
mridula-s109 Mar 26, 2025
9be3783
Modified rrfRank doc to the way main was
mridula-s109 Mar 27, 2025
7e2f732
Committing the changes done until now, will be doing a clean commit next
mridula-s109 Apr 2, 2025
3feaa3f
[CI] Auto commit changes from spotless
elasticsearchmachine Apr 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
MinScore implementation in Linear retriever
  • Loading branch information
mridula-s109 committed Mar 6, 2025
commit a41dac95a78e7cad4b9ca6e36995e6ca3e1f538f
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
private final RankDoc[] rankDocs;
private final QueryBuilder[] queryBuilders;
private final boolean onlyRankDocs;
private final float minScore;

public RankDocsQueryBuilder(RankDoc[] rankDocs, QueryBuilder[] queryBuilders, boolean onlyRankDocs) {
public RankDocsQueryBuilder(RankDoc[] rankDocs, QueryBuilder[] queryBuilders, boolean onlyRankDocs, float minScore) {
this.rankDocs = rankDocs;
this.queryBuilders = queryBuilders;
this.onlyRankDocs = onlyRankDocs;
this.minScore = minScore;
}

public RankDocsQueryBuilder(StreamInput in) throws IOException {
Expand All @@ -45,9 +47,11 @@ public RankDocsQueryBuilder(StreamInput in) throws IOException {
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
this.queryBuilders = in.readOptionalArray(c -> c.readNamedWriteable(QueryBuilder.class), QueryBuilder[]::new);
this.onlyRankDocs = in.readBoolean();
this.minScore = in.readFloat();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is written, it will break transport serialization.

What you want to do, is register a new TransportVersion for your change in TransportVersions.java. You'll then reference this new transport version to serialize the min score - older transport versions will always just use a default value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If i want to include min_score without transport versioning changes, then is the way it is currently written fine?

} else {
this.queryBuilders = null;
this.onlyRankDocs = false;
this.minScore = Float.MIN_VALUE;
}
}

Expand All @@ -70,7 +74,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
changed |= newQueryBuilders[i] != queryBuilders[i];
}
if (changed) {
RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs, minScore);
clone.queryName(queryName());
return clone;
}
Expand All @@ -88,6 +92,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
out.writeOptionalArray(StreamOutput::writeNamedWriteable, queryBuilders);
out.writeBoolean(onlyRankDocs);
out.writeFloat(minScore);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to reference the new transport version you create here as well, to ensure that serialization is consistent.

}
}

Expand Down Expand Up @@ -115,7 +120,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
queries = new Query[0];
queryNames = Strings.EMPTY_ARRAY;
}
return new RankDocsQuery(reader, shardRankDocs, queries, queryNames, onlyRankDocs);
return new RankDocsQuery(reader, shardRankDocs, queries, queryNames, onlyRankDocs, minScore);
}

@Override
Expand All @@ -135,12 +140,13 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
protected boolean doEquals(RankDocsQueryBuilder other) {
return Arrays.equals(rankDocs, other.rankDocs)
&& Arrays.equals(queryBuilders, other.queryBuilders)
&& onlyRankDocs == other.onlyRankDocs;
&& onlyRankDocs == other.onlyRankDocs
&& minScore == other.minScore;
}

@Override
protected int doHashCode() {
return Objects.hash(Arrays.hashCode(rankDocs), Arrays.hashCode(queryBuilders), onlyRankDocs);
return Objects.hash(Arrays.hashCode(rankDocs), Arrays.hashCode(queryBuilders), onlyRankDocs, minScore);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
public QueryBuilder topDocsQuery() {
assert queryVector != null : "query vector must be materialized at this point";
assert rankDocs != null : "rankDocs should have been materialized by now";
var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true);
var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true, Float.MIN_VALUE);
if (preFilterQueryBuilders.isEmpty()) {
return rankDocsQuery.queryName(retrieverName);
}
Expand All @@ -217,7 +217,8 @@ public QueryBuilder explainQuery() {
var rankDocsQuery = new RankDocsQueryBuilder(
rankDocs,
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity) },
true
false,
Float.MIN_VALUE
);
if (preFilterQueryBuilders.isEmpty()) {
return rankDocsQuery.queryName(retrieverName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ public QueryBuilder explainQuery() {
var explainQuery = new RankDocsQueryBuilder(
rankDocs.get(),
sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new),
true
true,
Float.MIN_VALUE
);
explainQuery.queryName(retrieverName());
return explainQuery;
Expand All @@ -113,17 +114,19 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
rankQuery = new RankDocsQueryBuilder(
rankDocResults,
sources.stream().map(RetrieverBuilder::topDocsQuery).toArray(QueryBuilder[]::new),
false
false,
Float.MIN_VALUE
);
} else {
rankQuery = new RankDocsQueryBuilder(
rankDocResults,
sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new),
false
false,
Float.MIN_VALUE
);
}
} else {
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false, Float.MIN_VALUE);
}
rankQuery.queryName(retrieverName());
// ignore prefilters of this level, they were already propagated to children
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,7 @@ public float getMaxScore(int docId) {
}

@Override
public float score() {
// We could still end up with a valid 0 score for a RankDoc
// so here we want to differentiate between this and all the tailQuery matches
// that would also produce a 0 score due to filtering, by setting the score to `Float.MIN_VALUE` instead for
// RankDoc matches.
public float score() throws IOException {
return Math.max(docs[upTo].score, Float.MIN_VALUE);
}

Expand Down Expand Up @@ -234,6 +230,7 @@ public int hashCode() {
// RankDocs provided. This query does not contribute to scoring, as it is set as filter when creating the weight
private final Query tailQuery;
private final boolean onlyRankDocs;
private final float minScore;

/**
* Creates a {@code RankDocsQuery} based on the provided docs.
Expand All @@ -242,8 +239,9 @@ public int hashCode() {
* @param sources The original queries that were used to compute the top documents
* @param queryNames The names (if present) of the original retrievers
* @param onlyRankDocs Whether the query should only match the provided rank docs
* @param minScore The minimum score threshold for documents to be included in total hits
*/
public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, String[] queryNames, boolean onlyRankDocs) {
public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, String[] queryNames, boolean onlyRankDocs, float minScore) {
assert sources.length == queryNames.length;
// clone to avoid side-effect after sorting
this.docs = rankDocs.clone();
Expand All @@ -260,13 +258,15 @@ public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, St
this.tailQuery = null;
}
this.onlyRankDocs = onlyRankDocs;
this.minScore = minScore;
}

private RankDocsQuery(RankDoc[] docs, Query topQuery, Query tailQuery, boolean onlyRankDocs) {
this.docs = docs;
this.topQuery = topQuery;
this.tailQuery = tailQuery;
this.onlyRankDocs = onlyRankDocs;
this.minScore = Float.MIN_VALUE;
}

private static int binarySearch(RankDoc[] docs, int fromIndex, int toIndex, int key) {
Expand Down Expand Up @@ -346,7 +346,41 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException {

@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
return combinedWeight.scorerSupplier(context);
return new ScorerSupplier() {
private final ScorerSupplier supplier = combinedWeight.scorerSupplier(context);

@Override
public Scorer get(long leadCost) throws IOException {
Scorer scorer = supplier.get(leadCost);
return new Scorer() {
@Override
public DocIdSetIterator iterator() {
return scorer.iterator();
}

@Override
public float getMaxScore(int docId) throws IOException {
return scorer.getMaxScore(docId);
}

@Override
public float score() throws IOException {
float score = scorer.score();
return score >= minScore ? score : 0f;
}

@Override
public int docID() {
return scorer.docID();
}
};
}

@Override
public long cost() {
return supplier.cost();
}
};
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.elasticsearch.search.rank.RankDoc;
Expand All @@ -31,6 +32,7 @@
import java.util.Random;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class RankDocsQueryBuilderTests extends AbstractQueryTestCase<RankDocsQueryBuilder> {
Expand All @@ -50,14 +52,30 @@ private RankDoc[] generateRandomRankDocs() {
@Override
protected RankDocsQueryBuilder doCreateTestQueryBuilder() {
RankDoc[] rankDocs = generateRandomRankDocs();
return new RankDocsQueryBuilder(rankDocs, null, false);
return new RankDocsQueryBuilder(rankDocs, null, false, Float.MIN_VALUE);
}

@Override
protected void doAssertLuceneQuery(RankDocsQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException {
assertTrue(query instanceof RankDocsQuery);
assertThat(query, instanceOf(RankDocsQuery.class));
RankDocsQuery rankDocsQuery = (RankDocsQuery) query;
assertArrayEquals(queryBuilder.rankDocs(), rankDocsQuery.rankDocs());
assertThat(rankDocsQuery.rankDocs(), equalTo(queryBuilder.rankDocs()));
}

protected Query createTestQuery() throws IOException {
return createRandomQuery().toQuery(createSearchExecutionContext());
}

private RankDocsQueryBuilder createQueryBuilder() {
return createRandomQuery();
}

private RankDocsQueryBuilder createRandomQuery() {
RankDoc[] rankDocs = new RankDoc[randomIntBetween(1, 5)];
for (int i = 0; i < rankDocs.length; i++) {
rankDocs[i] = new RankDoc(randomInt(), randomFloat(), randomIntBetween(0, 2));
}
return new RankDocsQueryBuilder(rankDocs, null, randomBoolean(), Float.MIN_VALUE);
}

/**
Expand Down Expand Up @@ -151,7 +169,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException {
rankDocs,
new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) },
new String[1],
false
false,
Float.MIN_VALUE
);
var topDocsManager = new TopScoreDocCollectorManager(topSize, null, totalHitsThreshold);
var col = searcher.search(q, topDocsManager);
Expand All @@ -172,7 +191,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException {
rankDocs,
new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) },
new String[1],
false
false,
Float.MIN_VALUE
);
var topDocsManager = new TopScoreDocCollectorManager(topSize, null, Integer.MAX_VALUE);
var col = searcher.search(q, topDocsManager);
Expand All @@ -187,7 +207,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException {
rankDocs,
new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) },
new String[1],
true
true,
Float.MIN_VALUE
);
var topDocsManager = new TopScoreDocCollectorManager(topSize, null, Integer.MAX_VALUE);
var col = searcher.search(q, topDocsManager);
Expand All @@ -204,7 +225,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException {
singleRankDoc,
new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) },
new String[1],
false
false,
Float.MIN_VALUE
);
var topDocsManager = new TopScoreDocCollectorManager(1, null, 0);
var col = searcher.search(q, topDocsManager);
Expand Down Expand Up @@ -257,10 +279,29 @@ public void shouldThrowForNegativeScores() throws IOException {
iw.addDocument(new Document());
try (IndexReader reader = iw.getReader()) {
SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false);
RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false, Float.MIN_VALUE);
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context));
assertEquals("RankDoc scores must be positive values. Missing a normalization step?", ex.getMessage());
}
}
}

public void testCreateQuery() throws IOException {
try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
iw.addDocument(new Document());
try (IndexReader reader = iw.getReader()) {
RankDoc[] rankDocs = new RankDoc[] { new RankDoc(0, randomFloat(), 0) };
RankDocsQuery q = new RankDocsQuery(
reader,
rankDocs,
new Query[] { new MatchAllDocsQuery() },
new String[] { "test" },
false,
Float.MIN_VALUE
);
assertNotNull(q);
assertArrayEquals(rankDocs, q.rankDocs());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void testRankDocSerialization() throws IOException {
for (int i = 0; i < totalDocs; i++) {
docs.add(createTestRankDoc());
}
RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean());
RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean(), Float.MIN_VALUE);
RankDocsQueryBuilder copy = (RankDocsQueryBuilder) copyNamedWriteable(rankDocsQueryBuilder, writableRegistry(), QueryBuilder.class);
assertThat(rankDocsQueryBuilder, equalTo(copy));
}
Expand Down
Loading