Skip to content

Commit d854b1c

Browse files
authored
Bugfix: fixed scroll with knn query (#126035)
Although scrolling is not recommended for knn queries, it is effective. But I found a bug that when use scroll in the knn query, the But I found a bug that when using scroll in knn query, knn_score_doc will be lost in query phase, which means knn query does not work. In addition, the operations for directly querying the node where the shard is located and querying the node with transport are different. It can be reproduced on the local node. Because the query phase uses the previous ShardSearchRequest object stored before the dfs phase. But when it run in the local node, it don't do the encode and decode processso the operation is correct. I wrote an IT to reproduce it and fixed it by adding the new source to the LegacyReaderContext.
1 parent c662590 commit d854b1c

File tree

3 files changed

+147
-0
lines changed

3 files changed

+147
-0
lines changed

docs/changelog/126035.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126035
2+
summary: Fix top level knn search with scroll
3+
area: Vector Search
4+
type: bug
5+
issues: []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search;
11+
12+
import org.elasticsearch.action.search.SearchRequest;
13+
import org.elasticsearch.action.search.SearchResponse;
14+
import org.elasticsearch.client.internal.Client;
15+
import org.elasticsearch.common.settings.Settings;
16+
import org.elasticsearch.core.TimeValue;
17+
import org.elasticsearch.index.query.QueryBuilders;
18+
import org.elasticsearch.search.builder.SearchSourceBuilder;
19+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
20+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
21+
import org.elasticsearch.test.ESIntegTestCase;
22+
import org.elasticsearch.xcontent.XContentBuilder;
23+
import org.elasticsearch.xcontent.XContentFactory;
24+
25+
import java.util.List;
26+
27+
import static org.hamcrest.Matchers.notNullValue;
28+
29+
@ESIntegTestCase.ClusterScope(minNumDataNodes = 2)
30+
public class KnnSearchIT extends ESIntegTestCase {
31+
32+
private static final String INDEX_NAME = "test_knn_index";
33+
private static final String VECTOR_FIELD = "vector";
34+
35+
private XContentBuilder createKnnMapping() throws Exception {
36+
return XContentFactory.jsonBuilder()
37+
.startObject()
38+
.startObject("properties")
39+
.startObject(VECTOR_FIELD)
40+
.field("type", "dense_vector")
41+
.field("dims", 2)
42+
.field("index", true)
43+
.field("similarity", "l2_norm")
44+
.startObject("index_options")
45+
.field("type", "hnsw")
46+
.endObject()
47+
.endObject()
48+
.startObject("category")
49+
.field("type", "keyword")
50+
.endObject()
51+
.endObject()
52+
.endObject();
53+
}
54+
55+
public void testKnnSearchWithScroll() throws Exception {
56+
final int numShards = randomIntBetween(1, 3);
57+
Client client = client();
58+
client.admin()
59+
.indices()
60+
.prepareCreate(INDEX_NAME)
61+
.setSettings(Settings.builder().put("index.number_of_shards", numShards))
62+
.setMapping(createKnnMapping())
63+
.get();
64+
65+
final int count = 100;
66+
for (int i = 0; i < count; i++) {
67+
XContentBuilder source = XContentFactory.jsonBuilder()
68+
.startObject()
69+
.field(VECTOR_FIELD, new float[] { i * 0.1f, i * 0.1f })
70+
.field("category", i >= 90 ? "last_ten" : null)
71+
.endObject();
72+
client.prepareIndex(INDEX_NAME).setSource(source).get();
73+
}
74+
refresh(INDEX_NAME);
75+
76+
final int k = randomIntBetween(11, 15);
77+
// test top level knn search
78+
{
79+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
80+
sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null)));
81+
executeScrollSearch(client, sourceBuilder, k);
82+
}
83+
// test top level knn search + another query
84+
{
85+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
86+
sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null)));
87+
sourceBuilder.query(QueryBuilders.existsQuery("category").boost(10));
88+
executeScrollSearch(client, sourceBuilder, k + 10);
89+
}
90+
91+
// test knn query
92+
{
93+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
94+
sourceBuilder.query(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null));
95+
executeScrollSearch(client, sourceBuilder, k * numShards);
96+
}
97+
// test knn query + another query
98+
{
99+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
100+
sourceBuilder.query(
101+
QueryBuilders.boolQuery()
102+
.should(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null))
103+
.should(QueryBuilders.existsQuery("category").boost(10))
104+
);
105+
executeScrollSearch(client, sourceBuilder, k * numShards + 10);
106+
}
107+
108+
}
109+
110+
private static void executeScrollSearch(Client client, SearchSourceBuilder sourceBuilder, int expectedNumHits) {
111+
SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
112+
searchRequest.source(sourceBuilder).scroll(TimeValue.timeValueMinutes(1));
113+
114+
SearchResponse searchResponse = client.search(searchRequest).actionGet();
115+
int hitsCollected = 0;
116+
float prevScore = Float.POSITIVE_INFINITY;
117+
try {
118+
do {
119+
assertThat(searchResponse.getScrollId(), notNullValue());
120+
assertEquals(expectedNumHits, searchResponse.getHits().getTotalHits().value());
121+
// assert correct order of returned hits
122+
for (var searchHit : searchResponse.getHits()) {
123+
assert (searchHit.getScore() <= prevScore);
124+
prevScore = searchHit.getScore();
125+
hitsCollected += 1;
126+
}
127+
searchResponse.decRef();
128+
searchResponse = client().prepareSearchScroll(searchResponse.getScrollId()).setScroll(TimeValue.timeValueMinutes(1)).get();
129+
} while (searchResponse.getHits().getHits().length > 0);
130+
} finally {
131+
assertEquals(expectedNumHits, hitsCollected);
132+
clearScroll(searchResponse.getScrollId());
133+
searchResponse.decRef();
134+
}
135+
}
136+
137+
}

server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java

+5
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ public Engine.Searcher acquireSearcher(String source) {
7272

7373
@Override
7474
public ShardSearchRequest getShardSearchRequest(ShardSearchRequest other) {
75+
if (other != null) {
76+
// The top level knn search modifies the source after the DFS phase.
77+
// so we need to update the source stored in the context.
78+
shardSearchRequest.source(other.source());
79+
}
7580
return shardSearchRequest;
7681
}
7782

0 commit comments

Comments
 (0)