[8.19] Fix inference model validation for the semantic text field #127559

Merged: 2 commits, Apr 30, 2025
ShardBulkInferenceActionFilterIT.java
@@ -26,13 +26,15 @@
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.xpack.inference.InferenceIndex;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -45,7 +47,11 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import static org.elasticsearch.xpack.inference.Utils.storeDenseModel;
import static org.elasticsearch.xpack.inference.Utils.storeModel;
import static org.elasticsearch.xpack.inference.Utils.storeSparseModel;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
import static org.hamcrest.Matchers.containsString;
@@ -56,6 +62,7 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {

private final boolean useLegacyFormat;
private final boolean useSyntheticSource;
private ModelRegistry modelRegistry;

public ShardBulkInferenceActionFilterIT(boolean useLegacyFormat, boolean useSyntheticSource) {
this.useLegacyFormat = useLegacyFormat;
@@ -74,16 +81,16 @@ public static Iterable<Object[]> parameters() throws Exception {

@Before
public void setup() throws Exception {
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
DenseVectorFieldMapper.ElementType elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
// dot product means that we need normalized vectors; it's not worth doing that in this test
SimilarityMeasure similarity = randomValueOtherThan(
SimilarityMeasure.DOT_PRODUCT,
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
);
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);
Utils.storeSparseModel(modelRegistry);
Utils.storeDenseModel(modelRegistry, dimensions, similarity, elementType);
storeSparseModel(modelRegistry);
storeDenseModel(modelRegistry, dimensions, similarity, elementType);
}

@Override
@@ -135,32 +142,131 @@ public void testBulkOperations() throws Exception {
TestDenseInferenceServiceExtension.TestInferenceService.NAME
)
).get();
assertRandomBulkOperations(INDEX_NAME, isIndexRequest -> {
Map<String, Object> map = new HashMap<>();
map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
map.put("dense_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
return map;
});
}

public void testItemFailures() {
prepareCreate(INDEX_NAME).setMapping(
String.format(
Locale.ROOT,
"""
{
"properties": {
"sparse_field": {
"type": "semantic_text",
"inference_id": "%s"
},
"dense_field": {
"type": "semantic_text",
"inference_id": "%s"
}
}
}
""",
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
TestDenseInferenceServiceExtension.TestInferenceService.NAME
)
).get();

BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
int totalBulkSize = randomIntBetween(100, 200); // Use a bulk request size large enough to require batching
for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
String id = Integer.toString(bulkSize);

// Set field values that will cause errors when generating inference requests
Map<String, Object> source = new HashMap<>();
source.put("sparse_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
source.put("dense_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));

bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
}

BulkResponse bulkResponse = bulkReqBuilder.get();
assertThat(bulkResponse.hasFailures(), equalTo(true));
for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
assertThat(bulkItemResponse.isFailed(), equalTo(true));
assertThat(bulkItemResponse.getFailureMessage(), containsString("expected [String|Number|Boolean]"));
}
}
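
The "expected [String|Number|Boolean]" failure asserted above is produced while the inference request is generated: semantic_text source values must be text-like, not objects. A minimal sketch of valid versus invalid values (illustrative only, not part of the PR; the variable names are hypothetical):

// Valid: plain strings (numbers and booleans are also accepted).
Map<String, Object> validSource = Map.of("sparse_field", "a short passage of text");

// Invalid: objects are rejected when building the inference request, producing
// the "expected [String|Number|Boolean]" item failure asserted above.
Map<String, Object> invalidSource = Map.of("sparse_field", List.of(Map.of("foo", "bar")));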

public void testRestart() throws Exception {
Model model1 = new TestSparseInferenceServiceExtension.TestSparseModel(
"another_inference_endpoint",
new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
);
storeModel(modelRegistry, model1);
prepareCreate("index_restart").setMapping("""
{
"properties": {
"sparse_field": {
"type": "semantic_text",
"inference_id": "new_inference_endpoint"
},
"other_field": {
"type": "semantic_text",
"inference_id": "another_inference_endpoint"
}
}
}
""").get();
Model model2 = new TestSparseInferenceServiceExtension.TestSparseModel(
"new_inference_endpoint",
new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
);
storeModel(modelRegistry, model2);

internalCluster().fullRestart(new InternalTestCluster.RestartCallback());
ensureGreen(InferenceIndex.INDEX_NAME, "index_restart");

assertRandomBulkOperations("index_restart", isIndexRequest -> {
Map<String, Object> map = new HashMap<>();
map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
map.put("other_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
return map;
});

internalCluster().fullRestart(new InternalTestCluster.RestartCallback());
ensureGreen(InferenceIndex.INDEX_NAME, "index_restart");

assertRandomBulkOperations("index_restart", isIndexRequest -> {
Map<String, Object> map = new HashMap<>();
map.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
map.put("other_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
return map;
});
}
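
The restart test above depends on an ordering guarantee that the mapper change in this PR makes explicit: a semantic_text mapping may reference an inference endpoint before that endpoint is registered, as long as the endpoint exists by the time documents are ingested. A condensed sketch of that ordering, reusing the fixtures from this test ("late_endpoint" and "my_index" are hypothetical names):

// 1) Creating the index succeeds even though "late_endpoint" is not registered yet.
prepareCreate("my_index").setMapping("""
    {
      "properties": {
        "my_field": { "type": "semantic_text", "inference_id": "late_endpoint" }
      }
    }
    """).get();

// 2) Register the endpoint afterwards, with the same test fixtures used above.
Model lateModel = new TestSparseInferenceServiceExtension.TestSparseModel(
    "late_endpoint",
    new TestSparseInferenceServiceExtension.TestServiceSettings("sparse_model", null, false)
);
storeModel(modelRegistry, lateModel);

// 3) Only now will indexing into "my_field" generate embeddings; until the endpoint
//    exists, indexing and querying on the field fail (sub-fields are created lazily).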

private void assertRandomBulkOperations(String indexName, Function<Boolean, Map<String, Object>> sourceSupplier) throws Exception {
int numHits = numHits(indexName);
int totalBulkReqs = randomIntBetween(2, 100);
long totalDocs = 0;
long totalDocs = numHits;
Set<String> ids = new HashSet<>();
for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) {

for (int bulkReqs = numHits; bulkReqs < totalBulkReqs; bulkReqs++) {
BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
int totalBulkSize = randomIntBetween(1, 100);
for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
if (ids.size() > 0 && rarely(random())) {
String id = randomFrom(ids);
ids.remove(id);
DeleteRequestBuilder request = new DeleteRequestBuilder(client(), INDEX_NAME).setId(id);
DeleteRequestBuilder request = new DeleteRequestBuilder(client(), indexName).setId(id);
bulkReqBuilder.add(request);
continue;
}
String id = Long.toString(totalDocs++);
boolean isIndexRequest = randomBoolean();
Map<String, Object> source = new HashMap<>();
source.put("sparse_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
source.put("dense_field", isIndexRequest && rarely() ? null : randomSemanticTextInput());
Map<String, Object> source = sourceSupplier.apply(isIndexRequest);
if (isIndexRequest) {
bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(indexName).setId(id).setSource(source));
ids.add(id);
} else {
boolean isUpsert = randomBoolean();
UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(INDEX_NAME).setDoc(source);
UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(indexName).setDoc(source);
if (isUpsert || ids.size() == 0) {
request.setDocAsUpsert(true);
} else {
@@ -188,59 +294,17 @@ public void testBulkOperations() throws Exception {
}
assertFalse(bulkResponse.hasFailures());
}
client().admin().indices().refresh(new RefreshRequest(indexName)).get();
assertThat(numHits(indexName), equalTo(ids.size() + numHits));
}

client().admin().indices().refresh(new RefreshRequest(INDEX_NAME)).get();

private int numHits(String indexName) throws Exception {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true);
SearchResponse searchResponse = client().search(new SearchRequest(INDEX_NAME).source(sourceBuilder)).get();
SearchResponse searchResponse = client().search(new SearchRequest(indexName).source(sourceBuilder)).get();
try {
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) ids.size()));
return (int) searchResponse.getHits().getTotalHits().value;
} finally {
searchResponse.decRef();
}
}

public void testItemFailures() {
prepareCreate(INDEX_NAME).setMapping(
String.format(
Locale.ROOT,
"""
{
"properties": {
"sparse_field": {
"type": "semantic_text",
"inference_id": "%s"
},
"dense_field": {
"type": "semantic_text",
"inference_id": "%s"
}
}
}
""",
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
TestDenseInferenceServiceExtension.TestInferenceService.NAME
)
).get();

BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
int totalBulkSize = randomIntBetween(100, 200); // Use a bulk request size large enough to require batching
for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
String id = Integer.toString(bulkSize);

// Set field values that will cause errors when generating inference requests
Map<String, Object> source = new HashMap<>();
source.put("sparse_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));
source.put("dense_field", List.of(Map.of("foo", "bar"), Map.of("baz", "bar")));

bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
}

BulkResponse bulkResponse = bulkReqBuilder.get();
assertThat(bulkResponse.hasFailures(), equalTo(true));
for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
assertThat(bulkItemResponse.isFailed(), equalTo(true));
assertThat(bulkItemResponse.getFailureMessage(), containsString("expected [String|Number|Boolean]"));
}
}
}
SemanticTextFieldMapper.java
@@ -45,6 +45,7 @@
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.MapperMergeContext;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.MappingLookup;
import org.elasticsearch.index.mapper.MappingParserContext;
import org.elasticsearch.index.mapper.NestedObjectMapper;
@@ -204,6 +205,7 @@ public static class Builder extends FieldMapper.Builder {

private final Parameter<Map<String, String>> meta = Parameter.metaParam();

private MinimalServiceSettings resolvedModelSettings;
private Function<MapperBuilderContext, ObjectMapper> inferenceFieldBuilder;

public static Builder from(SemanticTextFieldMapper mapper) {
@@ -283,21 +285,31 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
}

if (modelSettings.get() == null) {
if (context.getMergeReason() != MapperService.MergeReason.MAPPING_RECOVERY && modelSettings.get() == null) {
try {
var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
/*
* If the model is not already set and we are not in a recovery scenario, resolve it using the registry.
* Note: We do not set the model in the mapping at this stage. Instead, the model will be added through
* a mapping update during the first ingestion.
* This approach allows mappings to reference inference endpoints that may not yet exist.
* The only requirement is that the referenced inference endpoint must be available at the time of ingestion.
*/
resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
if (resolvedModelSettings != null) {
modelSettings.setValue(resolvedModelSettings);
validateServiceSettings(resolvedModelSettings, null);
}
} catch (ResourceNotFoundException exc) {
// We allow the inference ID to be unregistered at this point.
// This will delay the creation of sub-fields, so indexing and querying for this field won't work
// until the corresponding inference endpoint is created.
/* We allow the inference ID to be unregistered at this point.
* This will delay the creation of sub-fields, so indexing and querying for this field won't work
* until the corresponding inference endpoint is created.
*/
}
} else {
resolvedModelSettings = modelSettings.get();
}

if (modelSettings.get() != null) {
validateServiceSettings(modelSettings.get());
validateServiceSettings(modelSettings.get(), resolvedModelSettings);
} else {
logger.warn(
"The field [{}] references an unknown inference ID [{}]. "
Expand Down Expand Up @@ -333,7 +345,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
);
}

private void validateServiceSettings(MinimalServiceSettings settings) {
private void validateServiceSettings(MinimalServiceSettings settings, MinimalServiceSettings resolved) {
switch (settings.taskType()) {
case SPARSE_EMBEDDING, TEXT_EMBEDDING -> {
}
@@ -348,6 +360,17 @@ private void validateServiceSettings(MinimalServiceSettings settings) {
+ settings.taskType().name()
);
}

if (resolved != null && settings.canMergeWith(resolved) == false) {
throw new IllegalArgumentException(
"Mismatch between provided and registered inference model settings. "
+ "Provided: ["
+ settings
+ "], Expected: ["
+ resolved
+ "]."
);
}
}

/**
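
Because the view above interleaves removed and added lines, here is a best-effort reconstruction of the resulting build()-time flow. This is one reading of the hunks, not verbatim source; treat it as a sketch:

MinimalServiceSettings resolved = null;
if (context.getMergeReason() != MapperService.MergeReason.MAPPING_RECOVERY && modelSettings.get() == null) {
    try {
        // Resolve from the registry, but do not record the settings in the mapping yet;
        // a mapping update adds them during the first ingestion, so the endpoint only
        // needs to exist by the time documents are indexed.
        resolved = modelRegistry.getMinimalServiceSettings(inferenceId.get());
        if (resolved != null) {
            validateServiceSettings(resolved, null); // task-type check only
        }
    } catch (ResourceNotFoundException exc) {
        // Unknown endpoint: sub-field creation is deferred until it is created.
    }
} else {
    resolved = modelSettings.get();
}
if (modelSettings.get() != null) {
    // Settings explicitly present in the mapping must be mergeable with the
    // registry's view; a canMergeWith mismatch throws IllegalArgumentException.
    validateServiceSettings(modelSettings.get(), resolved);
} else {
    // Falls through to the "references an unknown inference ID" warning shown above.
}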
ModelRegistry.java
@@ -223,7 +223,6 @@ public void clearDefaultIds() {
*/
public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
synchronized (this) {
assert lastMetadata != null : "initial cluster state not set yet";
if (lastMetadata == null) {
throw new IllegalStateException("initial cluster state not set yet");
}