From 2c1f6869d023d9b07a880a9b61a0cf5103c1bef0 Mon Sep 17 00:00:00 2001 From: weizijun Date: Tue, 18 Mar 2025 15:10:26 +0800 Subject: [PATCH 1/5] add index.max_knn_num_candidates settings --- .../test/search.vectors/40_knn_search.yml | 29 ++++++++++++ .../common/settings/IndexScopedSettings.java | 1 + .../elasticsearch/index/IndexSettings.java | 23 ++++++++++ .../elasticsearch/search/dfs/DfsPhase.java | 12 ++++- .../search/vectors/KnnSearchBuilder.java | 8 +--- .../vectors/KnnSearchRequestParser.java | 4 -- .../search/vectors/KnnVectorQueryBuilder.java | 6 +-- .../search/dfs/DfsPhaseTests.java | 44 +++++++++++++++++++ .../search/vectors/KnnSearchBuilderTests.java | 8 ---- .../vectors/KnnSearchRequestParserTests.java | 16 ------- 10 files changed, 110 insertions(+), 41 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index 8f846dd76721d..aa6ef7a9ea8c4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -605,3 +605,32 @@ setup: - match: { hits.hits.0._score: $knn_score0 } - match: { hits.hits.1._score: $knn_score1 } - match: { hits.hits.2._score: $knn_score2 } + +--- +"kNN search with num_candidates exceeds max allowed value": + - requires: + reason: 'num_candidates exceeds max allowed value' + test_runner_features: [capabilities] + + - do: + indices.create: + index: test_num_candidates + body: + mappings: + properties: + vector: + type: dense_vector + element_type: float + dims: 5 + settings: + index.max_knn_num_candidates: 100 + - do: + catch: /\[num_candidates\] cannot exceed \[100\]/ + search: + index: test_num_candidates + body: + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 200 diff --git a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java index 422f4018dd69e..ae68fa30aada3 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java @@ -141,6 +141,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IndexSettings.MAX_REFRESH_LISTENERS_PER_SHARD, IndexSettings.MAX_SLICES_PER_SCROLL, IndexSettings.MAX_REGEX_LENGTH_SETTING, + IndexSettings.INDEX_MAX_KNN_NUM_CANDIDATES_SETTING, ShardsLimitAllocationDecider.INDEX_TOTAL_SHARDS_PER_NODE_SETTING, IndexSettings.INDEX_GC_DELETES_SETTING, IndexSettings.INDEX_SOFT_DELETES_SETTING, diff --git a/server/src/main/java/org/elasticsearch/index/IndexSettings.java b/server/src/main/java/org/elasticsearch/index/IndexSettings.java index e1002995559bc..b2350765fdb9d 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexSettings.java +++ b/server/src/main/java/org/elasticsearch/index/IndexSettings.java @@ -284,6 +284,17 @@ public final class IndexSettings { Property.IndexScope ); + /** + * The maximum number of candidates to be considered for KNN search. The default value is 10_000. + */ + public static final Setting INDEX_MAX_KNN_NUM_CANDIDATES_SETTING = Setting.intSetting( + "index.max_knn_num_candidates", + 10_000, + 1, + Property.Dynamic, + Property.IndexScope + ); + public static final TimeValue DEFAULT_REFRESH_INTERVAL = new TimeValue(1, TimeUnit.SECONDS); public static final Setting NODE_DEFAULT_REFRESH_INTERVAL_SETTING = Setting.timeSetting( "node._internal.default_refresh_interval", @@ -930,6 +941,8 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) { */ private volatile int maxRegexLength; + private volatile int maxKnnNumCandidates; + private final IndexRouting indexRouting; /** @@ -1083,6 +1096,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti mappingDepthLimit = scopedSettings.get(INDEX_MAPPING_DEPTH_LIMIT_SETTING); mappingFieldNameLengthLimit = scopedSettings.get(INDEX_MAPPING_FIELD_NAME_LENGTH_LIMIT_SETTING); mappingDimensionFieldsLimit = scopedSettings.get(INDEX_MAPPING_DIMENSION_FIELDS_LIMIT_SETTING); + maxKnnNumCandidates = scopedSettings.get(INDEX_MAX_KNN_NUM_CANDIDATES_SETTING); indexRouting = IndexRouting.fromIndexMetadata(indexMetadata); sourceKeepMode = scopedSettings.get(Mapper.SYNTHETIC_SOURCE_KEEP_INDEX_SETTING); es87TSDBCodecEnabled = scopedSettings.get(TIME_SERIES_ES87TSDB_CODEC_ENABLED_SETTING); @@ -1203,6 +1217,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti this::setSkipIgnoredSourceWrite ); scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead); + scopedSettings.addSettingsUpdateConsumer(INDEX_MAX_KNN_NUM_CANDIDATES_SETTING, this::setMaxKnnNumCandidates); } private void setSearchIdleAfter(TimeValue searchIdleAfter) { @@ -1821,4 +1836,12 @@ public TimestampBounds getTimestampBounds() { public IndexRouting getIndexRouting() { return indexRouting; } + + public int getMaxKnnNumCandidates() { + return maxKnnNumCandidates; + } + + public void setMaxKnnNumCandidates(int maxKnnNumCandidates) { + this.maxKnnNumCandidates = maxKnnNumCandidates; + } } diff --git a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java index 6a99b51ac679c..693efab186c65 100644 --- a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java +++ b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java @@ -44,6 +44,7 @@ import java.util.Map; import static org.elasticsearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; /** * DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase. @@ -177,7 +178,7 @@ private static Timer maybeStartTimer(DfsProfiler profiler, DfsTimingType dtt) { return null; }; - private static void executeKnnVectorQuery(SearchContext context) throws IOException { + static void executeKnnVectorQuery(SearchContext context) throws IOException { SearchSourceBuilder source = context.request().source(); if (source == null || source.knnSearch().isEmpty()) { return; @@ -186,6 +187,15 @@ private static void executeKnnVectorQuery(SearchContext context) throws IOExcept SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext(); List knnSearch = source.knnSearch(); List knnVectorQueryBuilders = knnSearch.stream().map(KnnSearchBuilder::toQueryBuilder).toList(); + int maxKnnNumCandidates = context.indexShard().indexSettings().getMaxKnnNumCandidates(); + for (KnnVectorQueryBuilder knnVectorQueryBuilder : knnVectorQueryBuilders) { + if (knnVectorQueryBuilder.numCands() != null && knnVectorQueryBuilder.numCands() > maxKnnNumCandidates) { + throw new IllegalArgumentException( + "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxKnnNumCandidates + "]" + ); + } + } + // Since we apply boost during the DfsQueryPhase, we should not apply boost here: knnVectorQueryBuilders.forEach(knnVectorQueryBuilder -> knnVectorQueryBuilder.boost(DEFAULT_BOOST)); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 9b9718efcf523..49191338729ee 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -43,7 +43,6 @@ * Defines a kNN search to run in the search request. */ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewriteable { - public static final int NUM_CANDS_LIMIT = 10_000; public static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f; public static final ParseField FIELD_FIELD = new ParseField("field"); @@ -264,9 +263,6 @@ private KnnSearchBuilder( "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]" ); } - if (numCandidates > NUM_CANDS_LIMIT) { - throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); - } if (queryVector == null && queryVectorBuilder == null) { throw new IllegalArgumentException( format( @@ -667,9 +663,7 @@ public Builder rescoreVectorBuilder(RescoreVectorBuilder rescoreVectorBuilder) { public KnnSearchBuilder build(int size) { int requestSize = size < 0 ? DEFAULT_SIZE : size; int adjustedK = k == null ? requestSize : k; - int adjustedNumCandidates = numCandidates == null - ? Math.round(Math.min(NUM_CANDS_LIMIT, NUM_CANDS_MULTIPLICATIVE_FACTOR * adjustedK)) - : numCandidates; + int adjustedNumCandidates = numCandidates == null ? Math.round(NUM_CANDS_MULTIPLICATIVE_FACTOR * adjustedK) : numCandidates; return new KnnSearchBuilder( field, queryVectorBuilder, diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java index 12573d5ad496e..c82fc135c6e8f 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java @@ -195,7 +195,6 @@ public void toSearchRequest(SearchRequestBuilder builder) { // visible for testing static class KnnSearch { - private static final int NUM_CANDS_LIMIT = 10000; static final ParseField FIELD_FIELD = new ParseField("field"); static final ParseField K_FIELD = new ParseField("k"); static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates"); @@ -253,9 +252,6 @@ public KnnVectorQueryBuilder toQueryBuilder() { "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than " + "[" + K_FIELD.getPreferredName() + "]" ); } - if (numCands > NUM_CANDS_LIMIT) { - throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); - } return new KnnVectorQueryBuilder(field, queryVector, numCands, numCands, null, null); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 565fd7325a5ac..f0a05cc71149e 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -56,7 +56,6 @@ */ public class KnnVectorQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "knn"; - private static final int NUM_CANDS_LIMIT = 10_000; private static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f; public static final ParseField FIELD_FIELD = new ParseField("field"); @@ -183,9 +182,6 @@ private KnnVectorQueryBuilder( if (k != null && k < 1) { throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0"); } - if (numCands != null && numCands > NUM_CANDS_LIMIT) { - throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); - } if (k != null && numCands != null && numCands < k) { throw new IllegalArgumentException( "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]" @@ -496,7 +492,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { k = Math.min(k, numCands); } } - int adjustedNumCands = numCands == null ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * k, NUM_CANDS_LIMIT)) : numCands; + int adjustedNumCands = numCands == null ? Math.round(NUM_CANDS_MULTIPLICATIVE_FACTOR * k) : numCands; if (fieldType == null) { return new MatchNoDocsQuery(); } diff --git a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java index d28bb98547cec..b40c4c0ed78e1 100644 --- a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java @@ -17,12 +17,20 @@ import org.apache.lucene.search.Query; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.ContextIndexSearcher; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult; import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.QueryProfileShardResult; +import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.IndexSettingsModule; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; @@ -32,6 +40,9 @@ import java.util.List; import java.util.concurrent.ThreadPoolExecutor; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class DfsPhaseTests extends ESTestCase { ThreadPoolExecutor threadPoolExecutor; @@ -102,4 +113,37 @@ public void testSingleKnnSearch() throws IOException { reader.close(); } } + + public void testNumCandidatesExceedsMax() { + Settings settings = Settings.builder().put("index.max_knn_num_candidates", 100).build(); + IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", settings); + + SearchContext context = mock(SearchContext.class); + when(context.indexShard()).thenAnswer(invocation -> { + IndexShard mockIndexShard = mock(IndexShard.class); + when(mockIndexShard.indexSettings()).thenReturn(indexSettings); + return mockIndexShard; + }); + + // 构造超过最大值的查询参数 + KnnSearchBuilder queryBuilder = new KnnSearchBuilder( + "float_vector", + new float[] { 0, 0, 0 }, + 10, + 150, // 超过maxKnnNumCandidates的值 + null, + null + ); + SearchSourceBuilder source = new SearchSourceBuilder(); + source.knnSearch(List.of(queryBuilder)); + ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); + when(searchRequest.source()).thenReturn(source); + when(context.request()).thenReturn(searchRequest); + + // 验证异常抛出 + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> DfsPhase.executeKnnVectorQuery(context)); + assertEquals("[num_candidates] cannot exceed [100]", e.getMessage()); + + } + } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 8cca3f9ed8a21..a85e9a933b288 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -238,14 +238,6 @@ public void testNumCandsLessThanK() { assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } - public void testNumCandsExceedsLimit() { - IllegalArgumentException e = expectThrows( - IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null) - ); - assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); - } - public void testInvalidK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java index 4e4d2158a9574..bd1d192283f2c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchRequestParserTests.java @@ -179,22 +179,6 @@ public void testNumCandsLessThanK() throws IOException { assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } - public void testNumCandsExceedsLimit() throws IOException { - XContentType xContentType = randomFrom(XContentType.values()); - XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()) - .startObject() - .startObject(KnnSearchRequestParser.KNN_SECTION_FIELD.getPreferredName()) - .field(KnnSearch.FIELD_FIELD.getPreferredName(), "field") - .field(KnnSearch.K_FIELD.getPreferredName(), 100) - .field(KnnSearch.NUM_CANDS_FIELD.getPreferredName(), 10002) - .field(KnnSearch.QUERY_VECTOR_FIELD.getPreferredName(), new float[] { 1.0f, 2.0f, 3.0f }) - .endObject() - .endObject(); - - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> parseSearchRequest(builder)); - assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); - } - public void testInvalidK() throws IOException { XContentType xContentType = randomFrom(XContentType.values()); XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()) From 1924f9c6a37987db2258464c0c8d1ad5a0f01008 Mon Sep 17 00:00:00 2001 From: weizijun Date: Tue, 18 Mar 2025 15:17:02 +0800 Subject: [PATCH 2/5] improve --- .../test/search.vectors/40_knn_search.yml | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index aa6ef7a9ea8c4..daed4f42a8781 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -623,7 +623,26 @@ setup: element_type: float dims: 5 settings: - index.max_knn_num_candidates: 100 + index.max_knn_num_candidates: 500 + + - do: + search: + index: test_num_candidates + body: + knn: + field: vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 200 + + - match: { hits.total.value: 0 } + + - do: + indices.put_settings: + index: test_num_candidates + body: + index.max_knn_num_candidates: 100 + - do: catch: /\[num_candidates\] cannot exceed \[100\]/ search: From 0573a3f5a22146570a22d8aa5803eb08776893bf Mon Sep 17 00:00:00 2001 From: weizijun Date: Wed, 19 Mar 2025 00:24:57 +0800 Subject: [PATCH 3/5] fixup --- .../test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java index b40c4c0ed78e1..b5de85ba0abd9 100644 --- a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java @@ -125,7 +125,6 @@ public void testNumCandidatesExceedsMax() { return mockIndexShard; }); - // 构造超过最大值的查询参数 KnnSearchBuilder queryBuilder = new KnnSearchBuilder( "float_vector", new float[] { 0, 0, 0 }, @@ -140,7 +139,6 @@ public void testNumCandidatesExceedsMax() { when(searchRequest.source()).thenReturn(source); when(context.request()).thenReturn(searchRequest); - // 验证异常抛出 IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> DfsPhase.executeKnnVectorQuery(context)); assertEquals("[num_candidates] cannot exceed [100]", e.getMessage()); From 6f772613eb77a50e35a8720141349ce9fc9ff2b9 Mon Sep 17 00:00:00 2001 From: weizijun Date: Wed, 19 Mar 2025 17:03:56 +0800 Subject: [PATCH 4/5] replace the check from dfs phase to KnnVectorQueryBuilder doToQuery --- .../elasticsearch/search/dfs/DfsPhase.java | 12 +----- .../search/vectors/KnnVectorQueryBuilder.java | 11 ++++- .../search/dfs/DfsPhaseTests.java | 42 ------------------- ...AbstractKnnVectorQueryBuilderTestCase.java | 18 ++++++++ 4 files changed, 28 insertions(+), 55 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java index 693efab186c65..6a99b51ac679c 100644 --- a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java +++ b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java @@ -44,7 +44,6 @@ import java.util.Map; import static org.elasticsearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; -import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; /** * DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase. @@ -178,7 +177,7 @@ private static Timer maybeStartTimer(DfsProfiler profiler, DfsTimingType dtt) { return null; }; - static void executeKnnVectorQuery(SearchContext context) throws IOException { + private static void executeKnnVectorQuery(SearchContext context) throws IOException { SearchSourceBuilder source = context.request().source(); if (source == null || source.knnSearch().isEmpty()) { return; @@ -187,15 +186,6 @@ static void executeKnnVectorQuery(SearchContext context) throws IOException { SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext(); List knnSearch = source.knnSearch(); List knnVectorQueryBuilders = knnSearch.stream().map(KnnSearchBuilder::toQueryBuilder).toList(); - int maxKnnNumCandidates = context.indexShard().indexSettings().getMaxKnnNumCandidates(); - for (KnnVectorQueryBuilder knnVectorQueryBuilder : knnVectorQueryBuilders) { - if (knnVectorQueryBuilder.numCands() != null && knnVectorQueryBuilder.numCands() > maxKnnNumCandidates) { - throw new IllegalArgumentException( - "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxKnnNumCandidates + "]" - ); - } - } - // Since we apply boost during the DfsQueryPhase, we should not apply boost here: knnVectorQueryBuilders.forEach(knnVectorQueryBuilder -> knnVectorQueryBuilder.boost(DEFAULT_BOOST)); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index f0a05cc71149e..725d2ee202d7b 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -482,7 +482,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType fieldType = context.getFieldType(fieldName); int k; if (this.k != null) { k = this.k; @@ -492,7 +491,15 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { k = Math.min(k, numCands); } } - int adjustedNumCands = numCands == null ? Math.round(NUM_CANDS_MULTIPLICATIVE_FACTOR * k) : numCands; + + int maxKnnNumCandidates = context.getIndexSettings().getMaxKnnNumCandidates(); + if (numCands != null && numCands > maxKnnNumCandidates) { + throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + maxKnnNumCandidates + "]"); + } + + int adjustedNumCands = numCands == null ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * k, maxKnnNumCandidates)) : numCands; + + MappedFieldType fieldType = context.getFieldType(fieldName); if (fieldType == null) { return new MatchNoDocsQuery(); } diff --git a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java index b5de85ba0abd9..d28bb98547cec 100644 --- a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java @@ -17,20 +17,12 @@ import org.apache.lucene.search.Query; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexSettings; -import org.elasticsearch.index.shard.IndexShard; -import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.ContextIndexSearcher; -import org.elasticsearch.search.internal.SearchContext; -import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult; import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.QueryProfileShardResult; -import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.IndexSettingsModule; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; @@ -40,9 +32,6 @@ import java.util.List; import java.util.concurrent.ThreadPoolExecutor; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - public class DfsPhaseTests extends ESTestCase { ThreadPoolExecutor threadPoolExecutor; @@ -113,35 +102,4 @@ public void testSingleKnnSearch() throws IOException { reader.close(); } } - - public void testNumCandidatesExceedsMax() { - Settings settings = Settings.builder().put("index.max_knn_num_candidates", 100).build(); - IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", settings); - - SearchContext context = mock(SearchContext.class); - when(context.indexShard()).thenAnswer(invocation -> { - IndexShard mockIndexShard = mock(IndexShard.class); - when(mockIndexShard.indexSettings()).thenReturn(indexSettings); - return mockIndexShard; - }); - - KnnSearchBuilder queryBuilder = new KnnSearchBuilder( - "float_vector", - new float[] { 0, 0, 0 }, - 10, - 150, // 超过maxKnnNumCandidates的值 - null, - null - ); - SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(queryBuilder)); - ShardSearchRequest searchRequest = mock(ShardSearchRequest.class); - when(searchRequest.source()).thenReturn(source); - when(context.request()).thenReturn(searchRequest); - - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> DfsPhase.executeKnnVectorQuery(context)); - assertEquals("[num_candidates] cannot exceed [100]", e.getMessage()); - - } - } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index b3764d528ff0f..e811364141928 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -21,6 +21,8 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.query.InnerHitsRewriteContext; @@ -34,6 +36,7 @@ import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.test.AbstractBuilderTestCase; import org.elasticsearch.test.AbstractQueryTestCase; +import org.elasticsearch.test.IndexSettingsModule; import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -53,6 +56,8 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase { private static final String VECTOR_FIELD = "vector"; @@ -458,4 +463,17 @@ public void testRewriteWithQueryVectorBuilder() throws Exception { assertThat(rewritten.filterQueries(), hasSize(numFilters)); assertThat(rewritten.filterQueries(), equalTo(filters)); } + + public void testMaxNumCandidatesExceeded() { + Settings settings = Settings.builder().put("index.max_knn_num_candidates", 100).build(); + IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", settings); + + SearchExecutionContext context = mock(SearchExecutionContext.class); + when(context.getIndexSettings()).thenReturn(indexSettings); + + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 150, null, null); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); + assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [100]")); + } } From 6a7c414c8beef2092198eb0d774f410b458a21d9 Mon Sep 17 00:00:00 2001 From: weizijun Date: Wed, 19 Mar 2025 17:24:08 +0800 Subject: [PATCH 5/5] add docs --- docs/reference/elasticsearch/index-settings/index-modules.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/reference/elasticsearch/index-settings/index-modules.md b/docs/reference/elasticsearch/index-settings/index-modules.md index b099313953671..5d1f9e253a944 100644 --- a/docs/reference/elasticsearch/index-settings/index-modules.md +++ b/docs/reference/elasticsearch/index-settings/index-modules.md @@ -205,6 +205,9 @@ $$$index-max-regex-length$$$ `index.max_regex_length` : The maximum length of value that can be used in `regexp` or `prefix` query. Defaults to `1000`. +`index.max_knn_num_candidates` +: The maximum number of candidates that can be used in KNN Query. Defaults to `10000`. + $$$index-query-default-field$$$ `index.query.default_field`