Skip to content

Revert semantic_text model registry changes #127075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions docs/changelog/126629.yaml

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ private static Version parseUnchecked(String version) {
public static final IndexVersion SYNTHETIC_SOURCE_STORE_ARRAYS_NATIVELY_SCALED_FLOAT = def(9_020_0_00, Version.LUCENE_10_1_0);
public static final IndexVersion USE_LUCENE101_POSTINGS_FORMAT = def(9_021_0_00, Version.LUCENE_10_1_0);
public static final IndexVersion UPGRADE_TO_LUCENE_10_2_0 = def(9_022_00_0, Version.LUCENE_10_2_0);
public static final IndexVersion SEMANTIC_TEXT_DEFAULTS_TO_BBQ = def(9_023_0_00, Version.LUCENE_10_2_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,6 @@ public Builder elementType(ElementType elementType) {
return this;
}

public Builder indexOptions(IndexOptions indexOptions) {
this.indexOptions.setValue(indexOptions);
return this;
}

@Override
public DenseVectorFieldMapper build(MapperBuilderContext context) {
// Validate again here because the dimensions or element type could have been set programmatically,
Expand Down Expand Up @@ -1226,7 +1221,7 @@ public final String toString() {
public abstract VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersion, ElementType elementType);
}

public abstract static class IndexOptions implements ToXContent {
abstract static class IndexOptions implements ToXContent {
final VectorIndexType type;

IndexOptions(VectorIndexType type) {
Expand All @@ -1235,36 +1230,21 @@ public abstract static class IndexOptions implements ToXContent {

abstract KnnVectorsFormat getVectorsFormat(ElementType elementType);

public boolean validate(ElementType elementType, int dim, boolean throwOnError) {
return validateElementType(elementType, throwOnError) && validateDimension(dim, throwOnError);
}

public boolean validateElementType(ElementType elementType) {
return validateElementType(elementType, true);
}

final boolean validateElementType(ElementType elementType, boolean throwOnError) {
boolean validElementType = type.supportsElementType(elementType);
if (throwOnError && validElementType == false) {
final void validateElementType(ElementType elementType) {
if (type.supportsElementType(elementType) == false) {
throw new IllegalArgumentException(
"[element_type] cannot be [" + elementType.toString() + "] when using index type [" + type + "]"
);
}
return validElementType;
}

abstract boolean updatableTo(IndexOptions update);

public boolean validateDimension(int dim) {
return validateDimension(dim, true);
}

public boolean validateDimension(int dim, boolean throwOnError) {
boolean supportsDimension = type.supportsDimension(dim);
if (throwOnError && supportsDimension == false) {
throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim);
public void validateDimension(int dim) {
if (type.supportsDimension(dim)) {
return;
}
return supportsDimension;
throw new IllegalArgumentException(type.name + " only supports even dimensions; provided=" + dim);
}

abstract boolean doEquals(IndexOptions other);
Expand Down Expand Up @@ -1767,12 +1747,12 @@ boolean updatableTo(IndexOptions update) {

}

public static class Int8HnswIndexOptions extends QuantizedIndexOptions {
static class Int8HnswIndexOptions extends QuantizedIndexOptions {
private final int m;
private final int efConstruction;
private final Float confidenceInterval;

public Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) {
Int8HnswIndexOptions(int m, int efConstruction, Float confidenceInterval, RescoreVector rescoreVector) {
super(VectorIndexType.INT8_HNSW, rescoreVector);
this.m = m;
this.efConstruction = efConstruction;
Expand Down Expand Up @@ -1910,11 +1890,11 @@ public String toString() {
}
}

public static class BBQHnswIndexOptions extends QuantizedIndexOptions {
static class BBQHnswIndexOptions extends QuantizedIndexOptions {
private final int m;
private final int efConstruction;

public BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVector) {
BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVector) {
super(VectorIndexType.BBQ_HNSW, rescoreVector);
this.m = m;
this.efConstruction = efConstruction;
Expand Down Expand Up @@ -1956,14 +1936,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
public boolean validateDimension(int dim, boolean throwOnError) {
boolean supportsDimension = type.supportsDimension(dim);
if (throwOnError && supportsDimension == false) {
throw new IllegalArgumentException(
type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim
);
public void validateDimension(int dim) {
if (type.supportsDimension(dim)) {
return;
}
return supportsDimension;
throw new IllegalArgumentException(type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim);
}
}

Expand Down Expand Up @@ -2007,19 +1984,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

@Override
public boolean validateDimension(int dim, boolean throwOnError) {
boolean supportsDimension = type.supportsDimension(dim);
if (throwOnError && supportsDimension == false) {
throw new IllegalArgumentException(
type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim
);
public void validateDimension(int dim) {
if (type.supportsDimension(dim)) {
return;
}
return supportsDimension;
throw new IllegalArgumentException(type.name + " does not support dimensions fewer than " + BBQ_MIN_DIMS + "; provided=" + dim);
}

}

public record RescoreVector(float oversample) implements ToXContentObject {
record RescoreVector(float oversample) implements ToXContentObject {
static final String NAME = "rescore_vector";
static final String OVERSAMPLE = "oversample";

Expand Down Expand Up @@ -2338,10 +2311,6 @@ int getVectorDimensions() {
ElementType getElementType() {
return elementType;
}

public IndexOptions getIndexOptions() {
return indexOptions;
}
}

private final IndexOptions indexOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,18 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas
}
}

public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
}

/**
* Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
*/
public boolean canMergeWith(MinimalServiceSettings other) {
return taskType == other.taskType
&& Objects.equals(dimensions, other.dimensions)
&& similarity == other.similarity
&& elementType == other.elementType;
&& elementType == other.elementType
&& (service == null || service.equals(other.service));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,6 @@ protected final MapperService createMapperService(Settings settings, String mapp
return mapperService;
}

protected final MapperService createMapperService(IndexVersion indexVersion, Settings settings, XContentBuilder mappings)
throws IOException {
MapperService mapperService = createMapperService(indexVersion, settings, () -> true, mappings);
merge(mapperService, mappings);
return mapperService;
}

protected final MapperService createMapperService(IndexVersion version, XContentBuilder mapping) throws IOException {
return createMapperService(version, getIndexSettings(), () -> true, mapping);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ public class InferencePlugin extends Plugin
private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;

public InferencePlugin(Settings settings) {
Expand Down Expand Up @@ -261,8 +260,8 @@ public Collection<?> createComponents(PluginServices services) {
var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);

modelRegistry.set(new ModelRegistry(services.clusterService(), services.client()));
services.clusterService().addListener(modelRegistry.get());
ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
services.clusterService().addListener(modelRegistry);

if (inferenceServiceExtensions == null) {
inferenceServiceExtensions = new ArrayList<>();
Expand Down Expand Up @@ -300,7 +299,7 @@ public Collection<?> createComponents(PluginServices services) {
elasicInferenceServiceFactory.get(),
serviceComponents.get(),
inferenceServiceSettings,
modelRegistry.get(),
modelRegistry,
authorizationHandler
)
)
Expand All @@ -318,14 +317,14 @@ public Collection<?> createComponents(PluginServices services) {
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
serviceRegistry.init(services.client());
for (var service : serviceRegistry.getServices().values()) {
service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds);
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
}
inferenceServiceRegistry.set(serviceRegistry);

var actionFilter = new ShardBulkInferenceActionFilter(
services.clusterService(),
serviceRegistry,
modelRegistry.get(),
modelRegistry,
getLicenseState(),
services.indexingPressure()
);
Expand All @@ -335,7 +334,7 @@ public Collection<?> createComponents(PluginServices services) {
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));

components.add(serviceRegistry);
components.add(modelRegistry.get());
components.add(modelRegistry);
components.add(httpClientManager);
components.add(inferenceStats);

Expand Down Expand Up @@ -499,16 +498,11 @@ public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER);
}

// Overridable for tests
protected Supplier<ModelRegistry> getModelRegistry() {
return () -> modelRegistry.get();
}

@Override
public Map<String, Mapper.TypeParser> getMappers() {
return Map.of(
SemanticTextFieldMapper.CONTENT_TYPE,
SemanticTextFieldMapper.parser(getModelRegistry()),
SemanticTextFieldMapper.PARSER,
OffsetSourceFieldMapper.CONTENT_TYPE,
OffsetSourceFieldMapper.PARSER
);
Expand Down
Loading
Loading