Enable concurrent intra merge for HNSW graphs #108164

Open · wants to merge 1 commit into base: lucene_snapshot
@@ -65,9 +65,6 @@ public ES814HnswScalarQuantizedVectorsFormat(
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers > 1 && mergeExec == null) {
throw new IllegalArgumentException("No executor service passed in when " + numMergeWorkers + " merge workers are requested");
}
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
}
@@ -30,6 +30,7 @@
import java.util.Collections;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.Executor;

/**
* An extension to the {@link ConcurrentMergeScheduler} that provides tracking on merge times, total
@@ -78,6 +79,17 @@ protected boolean verbose() {
return super.verbose();
}

// TODO: this is temporary; remove this override and enable multithreaded merges for all kinds of merges
@Override
public Executor getIntraMergeExecutor(MergePolicy.OneMerge merge) {
// Enable multithreaded merges only for force merge operations
if (merge.getStoreMergeInfo().mergeMaxNumSegments != -1) {
return super.getIntraMergeExecutor(merge);
} else {
return null;
}
}

Comment on lines +82 to +92 (Member):
I do not think this should return null. It should instead return a same thread executor.
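
A minimal sketch of what the reviewer suggests, keeping the PR's force-merge check but never returning null; the SAME_THREAD constant is a hypothetical name used here for illustration, not part of this PR:

    // Hypothetical alternative: give non-force merges a same-thread executor instead of null.
    private static final Executor SAME_THREAD = Runnable::run; // runs each task on the calling thread

    @Override
    public Executor getIntraMergeExecutor(MergePolicy.OneMerge merge) {
        // Multithreaded intra-merge only for force merges (mergeMaxNumSegments is set).
        if (merge.getStoreMergeInfo().mergeMaxNumSegments != -1) {
            return super.getIntraMergeExecutor(merge);
        }
        // Everything else runs inline on the merge thread rather than getting a null executor.
        return SAME_THREAD;
    }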

@Override
/** Overridden to route specific MergeThread messages to our logger. */
protected void message(String message) {
@@ -629,7 +629,8 @@ private static void postProcessDynamicArrayMapping(DocumentParserContext context

DenseVectorFieldMapper.Builder builder = new DenseVectorFieldMapper.Builder(
fieldName,
context.indexSettings().getIndexVersionCreated()
context.indexSettings().getIndexVersionCreated(),
context.indexSettings().getMergeSchedulerConfig().getMaxThreadCount()
);
DenseVectorFieldMapper denseVectorFieldMapper = builder.build(builderContext);
context.updateDynamicMappers(fullFieldName, List.of(denseVectorFieldMapper));
@@ -159,10 +159,12 @@ public static class Builder extends FieldMapper.Builder {
private final Parameter<Map<String, String>> meta = Parameter.metaParam();

final IndexVersion indexVersionCreated;
final int mergeThreadCount;

public Builder(String name, IndexVersion indexVersionCreated) {
public Builder(String name, IndexVersion indexVersionCreated, int mergeThreadCount) {
super(name);
this.indexVersionCreated = indexVersionCreated;
this.mergeThreadCount = mergeThreadCount;
final boolean indexedByDefault = indexVersionCreated.onOrAfter(INDEXED_BY_DEFAULT_INDEX_VERSION);
final boolean defaultInt8Hnsw = indexVersionCreated.onOrAfter(DEFAULT_DENSE_VECTOR_TO_INT8_HNSW);
this.indexed = Parameter.indexParam(m -> toType(m).fieldType().indexed, indexedByDefault);
@@ -255,6 +257,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
),
indexOptions.getValue(),
indexVersionCreated,
mergeThreadCount,
multiFieldsBuilder.build(this, context),
copyTo
);
@@ -838,7 +841,7 @@ private abstract static class IndexOptions implements ToXContent {
this.type = type;
}

abstract KnnVectorsFormat getVectorsFormat();
abstract KnnVectorsFormat getVectorsFormat(int mergeThreadCount);
Comment (Member):
Something about this really bugs me. The merge thread count is a dynamically updatable value. But doing this, is it really dynamic for the workers in the dense vector field mapper?
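
A small standalone illustration of the reviewer's concern, with hypothetical names that are not part of this PR: the thread count copied into the Builder is a snapshot, so a later update to the dynamic merge scheduler setting is never seen by getVectorsFormat, whereas a deferred read through a supplier would track it.

    import java.util.function.IntSupplier;

    public class MergeThreadCountSnapshot {
        public static void main(String[] args) {
            int[] liveSetting = { 1 };                    // stands in for the dynamic merge scheduler max thread count
            int snapshot = liveSetting[0];                // what copying the value into the Builder does
            IntSupplier deferred = () -> liveSetting[0];  // a deferred read instead of a copy

            liveSetting[0] = 4;                           // the dynamic setting is updated later

            System.out.println(snapshot);                 // prints 1 - the copied value is stale
            System.out.println(deferred.getAsInt());      // prints 4 - a supplier-style read tracks the update
        }
    }

Whether the format lookup should take a supplier rather than a plain int is a design question for the PR; the snippet only shows why the copied value cannot follow dynamic updates.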


boolean supportsElementType(ElementType elementType) {
return true;
@@ -938,7 +941,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
KnnVectorsFormat getVectorsFormat() {
KnnVectorsFormat getVectorsFormat(int mergeThreadCount) {
return new ES813Int8FlatVectorFormat(confidenceInterval);
}

@@ -976,7 +979,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
KnnVectorsFormat getVectorsFormat() {
KnnVectorsFormat getVectorsFormat(int mergeThreadCount) {
return new ES813FlatVectorFormat();
}

@@ -1005,10 +1008,10 @@ private Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval
}

@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(int mergeThreadCount) {
// int bits = 7;
// boolean compress = false; // TODO we only support 7 and false, for now
return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, 1, confidenceInterval, null);
return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, mergeThreadCount, confidenceInterval, null);
}

@Override
@@ -1067,8 +1070,8 @@ private HnswIndexOptions(int m, int efConstruction) {
}

@Override
public KnnVectorsFormat getVectorsFormat() {
return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null);
public KnnVectorsFormat getVectorsFormat(int mergeThreadCount) {
return new Lucene99HnswVectorsFormat(m, efConstruction, mergeThreadCount, null);
}

@Override
@@ -1101,7 +1104,7 @@ public String toString() {
}

public static final TypeParser PARSER = new TypeParser(
(n, c) -> new Builder(n, c.indexVersionCreated()),
(n, c) -> new Builder(n, c.indexVersionCreated(), c.getIndexSettings().getMergeSchedulerConfig().getMaxThreadCount()),
notInMultiFields(CONTENT_TYPE)
);

@@ -1394,18 +1397,21 @@ ElementType getElementType() {

private final IndexOptions indexOptions;
private final IndexVersion indexCreatedVersion;
private final int mergeThreadCount;

private DenseVectorFieldMapper(
String simpleName,
MappedFieldType mappedFieldType,
IndexOptions indexOptions,
IndexVersion indexCreatedVersion,
int mergeThreadCount,
MultiFields multiFields,
CopyTo copyTo
) {
super(simpleName, mappedFieldType, multiFields, copyTo);
this.indexOptions = indexOptions;
this.indexCreatedVersion = indexCreatedVersion;
this.mergeThreadCount = mergeThreadCount;
}

@Override
@@ -1448,6 +1454,7 @@ public void parse(DocumentParserContext context) throws IOException {
updatedDenseVectorFieldType,
indexOptions,
indexCreatedVersion,
mergeThreadCount,
multiFields(),
copyTo
);
@@ -1535,7 +1542,7 @@ protected String contentType() {

@Override
public FieldMapper.Builder getMergeBuilder() {
return new Builder(simpleName(), indexCreatedVersion).init(this);
return new Builder(simpleName(), indexCreatedVersion, mergeThreadCount).init(this);
}

private static IndexOptions parseIndexOptions(String fieldName, Object propNode) {
Expand All @@ -1560,7 +1567,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultForm
if (indexOptions == null) {
format = defaultFormat;
} else {
format = indexOptions.getVectorsFormat();
format = indexOptions.getVectorsFormat(mergeThreadCount);
}
// It's legal to reuse the same format name as this is the same on-disk format.
return new KnnVectorsFormat(format.getName()) {
@@ -65,12 +65,13 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
public static final String CONTENT_TYPE = "semantic_text";

public static final TypeParser PARSER = new TypeParser(
(n, c) -> new Builder(n, c.indexVersionCreated()),
(n, c) -> new Builder(n, c.indexVersionCreated(), c.getIndexSettings().getMergeSchedulerConfig().getMaxThreadCount()),
notInMultiFields(CONTENT_TYPE)
);

public static class Builder extends FieldMapper.Builder {
private final IndexVersion indexVersionCreated;
private final int mergeThreadCount;

private final Parameter<String> inferenceId = Parameter.stringParam(
"inference_id",
Expand All @@ -97,10 +98,11 @@ public static class Builder extends FieldMapper.Builder {

private Function<MapperBuilderContext, ObjectMapper> inferenceFieldBuilder;

public Builder(String name, IndexVersion indexVersionCreated) {
public Builder(String name, IndexVersion indexVersionCreated, int mergeThreadCount) {
super(name);
this.indexVersionCreated = indexVersionCreated;
this.inferenceFieldBuilder = c -> createInferenceField(c, indexVersionCreated, modelSettings.get());
this.mergeThreadCount = mergeThreadCount;
this.inferenceFieldBuilder = c -> createInferenceField(c, indexVersionCreated, mergeThreadCount, modelSettings.get());
}

public Builder setInferenceId(String id) {
@@ -148,6 +150,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
modelSettings.getValue(),
inferenceField,
indexVersionCreated,
mergeThreadCount,
meta.getValue()
),
copyTo
Expand All @@ -168,7 +171,7 @@ public Iterator<Mapper> iterator() {

@Override
public FieldMapper.Builder getMergeBuilder() {
return new Builder(simpleName(), fieldType().indexVersionCreated).init(this);
return new Builder(simpleName(), fieldType().indexVersionCreated, fieldType().mergeThreadCount).init(this);
}

@Override
@@ -203,7 +206,7 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio
final SemanticTextFieldMapper mapper;
if (fieldType().getModelSettings() == null) {
context.path().remove();
Builder builder = (Builder) new Builder(simpleName(), fieldType().indexVersionCreated).init(this);
Builder builder = (Builder) new Builder(simpleName(), fieldType().indexVersionCreated, fieldType().mergeThreadCount).init(this);
try {
mapper = builder.setModelSettings(field.inference().modelSettings())
.setInferenceId(field.inference().inferenceId())
@@ -270,20 +273,23 @@ public static class SemanticTextFieldType extends SimpleMappedFieldType {
private final SemanticTextField.ModelSettings modelSettings;
private final ObjectMapper inferenceField;
private final IndexVersion indexVersionCreated;
private final int mergeThreadCount;

public SemanticTextFieldType(
String name,
String modelId,
SemanticTextField.ModelSettings modelSettings,
ObjectMapper inferenceField,
IndexVersion indexVersionCreated,
int mergeThreadCount,
Map<String, String> meta
) {
super(name, false, false, false, TextSearchInfo.NONE, meta);
this.inferenceId = modelId;
this.modelSettings = modelSettings;
this.inferenceField = inferenceField;
this.indexVersionCreated = indexVersionCreated;
this.mergeThreadCount = mergeThreadCount;
}

@Override
@@ -331,35 +337,42 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext
private static ObjectMapper createInferenceField(
MapperBuilderContext context,
IndexVersion indexVersionCreated,
int mergeThreadCount,
@Nullable SemanticTextField.ModelSettings modelSettings
) {
return new ObjectMapper.Builder(INFERENCE_FIELD, Explicit.EXPLICIT_TRUE).dynamic(ObjectMapper.Dynamic.FALSE)
.add(createChunksField(indexVersionCreated, modelSettings))
.add(createChunksField(indexVersionCreated, mergeThreadCount, modelSettings))
.build(context);
}

private static NestedObjectMapper.Builder createChunksField(
IndexVersion indexVersionCreated,
int mergeThreadCount,
SemanticTextField.ModelSettings modelSettings
) {
NestedObjectMapper.Builder chunksField = new NestedObjectMapper.Builder(CHUNKS_FIELD, indexVersionCreated);
chunksField.dynamic(ObjectMapper.Dynamic.FALSE);
KeywordFieldMapper.Builder chunkTextField = new KeywordFieldMapper.Builder(CHUNKED_TEXT_FIELD, indexVersionCreated).indexed(false)
.docValues(false);
if (modelSettings != null) {
chunksField.add(createEmbeddingsField(indexVersionCreated, modelSettings));
chunksField.add(createEmbeddingsField(indexVersionCreated, mergeThreadCount, modelSettings));
}
chunksField.add(chunkTextField);
return chunksField;
}

private static Mapper.Builder createEmbeddingsField(IndexVersion indexVersionCreated, SemanticTextField.ModelSettings modelSettings) {
private static Mapper.Builder createEmbeddingsField(
IndexVersion indexVersionCreated,
int mergeThreadCount,
SemanticTextField.ModelSettings modelSettings
) {
return switch (modelSettings.taskType()) {
case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(CHUNKED_EMBEDDINGS_FIELD);
case TEXT_EMBEDDING -> {
DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder(
CHUNKED_EMBEDDINGS_FIELD,
indexVersionCreated
indexVersionCreated,
mergeThreadCount
);
SimilarityMeasure similarity = modelSettings.similarity();
if (similarity != null) {