Skip to content

Commit 7bedd0d

Browse files
authored
Expose model registry to SemanticTextFieldMapper (#126635) (#126817)
This change integrates the new model registry with the `SemanticTextFieldMapper`, allowing inference IDs to be eagerly resolved at parse time. It also preserves the existing lenient behavior: no error is thrown if the specified inference id does not exist, only a warning is logged.
1 parent bab975a commit 7bedd0d

File tree

7 files changed

+193
-26
lines changed

7 files changed

+193
-26
lines changed

server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -242,18 +242,13 @@ private static void validateFieldNotPresent(String field, Object fieldValue, Tas
242242
}
243243
}
244244

245-
public ModelConfigurations toModelConfigurations(String inferenceEntityId) {
246-
return new ModelConfigurations(inferenceEntityId, taskType, service == null ? UNKNOWN_SERVICE : service, this);
247-
}
248-
249245
/**
250246
* Checks if the given {@link MinimalServiceSettings} is equivalent to the current definition.
251247
*/
252248
public boolean canMergeWith(MinimalServiceSettings other) {
253249
return taskType == other.taskType
254250
&& Objects.equals(dimensions, other.dimensions)
255251
&& similarity == other.similarity
256-
&& elementType == other.elementType
257-
&& (service == null || service.equals(other.service));
252+
&& elementType == other.elementType;
258253
}
259254
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ public class InferencePlugin extends Plugin
199199
private final SetOnce<ElasticInferenceServiceComponents> elasticInferenceServiceComponents = new SetOnce<>();
200200
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
201201
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
202+
private final SetOnce<ModelRegistry> modelRegistry = new SetOnce<>();
202203
private List<InferenceServiceExtension> inferenceServiceExtensions;
203204

204205
public InferencePlugin(Settings settings) {
@@ -262,8 +263,8 @@ public Collection<?> createComponents(PluginServices services) {
262263
var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService());
263264
amazonBedrockFactory.set(amazonBedrockRequestSenderFactory);
264265

265-
ModelRegistry modelRegistry = new ModelRegistry(services.clusterService(), services.client());
266-
services.clusterService().addListener(modelRegistry);
266+
modelRegistry.set(new ModelRegistry(services.clusterService(), services.client()));
267+
services.clusterService().addListener(modelRegistry.get());
267268

268269
if (inferenceServiceExtensions == null) {
269270
inferenceServiceExtensions = new ArrayList<>();
@@ -301,7 +302,7 @@ public Collection<?> createComponents(PluginServices services) {
301302
elasicInferenceServiceFactory.get(),
302303
serviceComponents.get(),
303304
inferenceServiceSettings,
304-
modelRegistry,
305+
modelRegistry.get(),
305306
authorizationHandler
306307
)
307308
)
@@ -319,18 +320,23 @@ public Collection<?> createComponents(PluginServices services) {
319320
var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
320321
serviceRegistry.init(services.client());
321322
for (var service : serviceRegistry.getServices().values()) {
322-
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
323+
service.defaultConfigIds().forEach(modelRegistry.get()::addDefaultIds);
323324
}
324325
inferenceServiceRegistry.set(serviceRegistry);
325326

326-
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
327+
var actionFilter = new ShardBulkInferenceActionFilter(
328+
services.clusterService(),
329+
serviceRegistry,
330+
modelRegistry.get(),
331+
getLicenseState()
332+
);
327333
shardBulkInferenceActionFilter.set(actionFilter);
328334

329335
var meterRegistry = services.telemetryProvider().getMeterRegistry();
330336
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
331337

332338
components.add(serviceRegistry);
333-
components.add(modelRegistry);
339+
components.add(modelRegistry.get());
334340
components.add(httpClientManager);
335341
components.add(inferenceStats);
336342

@@ -497,11 +503,16 @@ public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
497503
return Map.of(SemanticInferenceMetadataFieldsMapper.NAME, SemanticInferenceMetadataFieldsMapper.PARSER);
498504
}
499505

506+
// Overridable for tests
507+
protected Supplier<ModelRegistry> getModelRegistry() {
508+
return modelRegistry::get;
509+
}
510+
500511
@Override
501512
public Map<String, Mapper.TypeParser> getMappers() {
502513
return Map.of(
503514
SemanticTextFieldMapper.CONTENT_TYPE,
504-
SemanticTextFieldMapper.PARSER,
515+
SemanticTextFieldMapper.parser(getModelRegistry()),
505516
OffsetSourceFieldMapper.CONTENT_TYPE,
506517
OffsetSourceFieldMapper.PARSER
507518
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.mapper;
99

10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
1012
import org.apache.lucene.index.FieldInfos;
1113
import org.apache.lucene.index.LeafReaderContext;
1214
import org.apache.lucene.search.DocIdSetIterator;
@@ -18,6 +20,7 @@
1820
import org.apache.lucene.search.join.BitSetProducer;
1921
import org.apache.lucene.search.join.ScoreMode;
2022
import org.apache.lucene.util.BitSet;
23+
import org.elasticsearch.ResourceNotFoundException;
2124
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
2225
import org.elasticsearch.common.Strings;
2326
import org.elasticsearch.common.bytes.BytesReference;
@@ -75,6 +78,7 @@
7578
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
7679
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
7780
import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
81+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
7882

7983
import java.io.IOException;
8084
import java.io.UncheckedIOException;
@@ -89,6 +93,7 @@
8993
import java.util.Set;
9094
import java.util.function.BiConsumer;
9195
import java.util.function.Function;
96+
import java.util.function.Supplier;
9297

9398
import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
9499
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
@@ -112,6 +117,7 @@
112117
* A {@link FieldMapper} for semantic text fields.
113118
*/
114119
public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
120+
private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class);
115121
public static final NodeFeature SEMANTIC_TEXT_SEARCH_INFERENCE_ID = new NodeFeature("semantic_text.search_inference_id", true);
116122
public static final NodeFeature SEMANTIC_TEXT_DEFAULT_ELSER_2 = new NodeFeature("semantic_text.default_elser_2", true);
117123
public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix");
@@ -129,10 +135,12 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
129135
public static final String CONTENT_TYPE = "semantic_text";
130136
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
131137

132-
public static final TypeParser PARSER = new TypeParser(
133-
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()),
134-
List.of(validateParserContext(CONTENT_TYPE))
135-
);
138+
public static final TypeParser parser(Supplier<ModelRegistry> modelRegistry) {
139+
return new TypeParser(
140+
(n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings(), modelRegistry.get()),
141+
List.of(validateParserContext(CONTENT_TYPE))
142+
);
143+
}
136144

137145
public static BiConsumer<String, MappingParserContext> validateParserContext(String type) {
138146
return (n, c) -> {
@@ -144,6 +152,7 @@ public static BiConsumer<String, MappingParserContext> validateParserContext(Str
144152
}
145153

146154
public static class Builder extends FieldMapper.Builder {
155+
private final ModelRegistry modelRegistry;
147156
private final boolean useLegacyFormat;
148157

149158
private final Parameter<String> inferenceId = Parameter.stringParam(
@@ -201,14 +210,21 @@ public static Builder from(SemanticTextFieldMapper mapper) {
201210
Builder builder = new Builder(
202211
mapper.leafName(),
203212
mapper.fieldType().getChunksField().bitsetProducer(),
204-
mapper.fieldType().getChunksField().indexSettings()
213+
mapper.fieldType().getChunksField().indexSettings(),
214+
mapper.modelRegistry
205215
);
206216
builder.init(mapper);
207217
return builder;
208218
}
209219

210-
public Builder(String name, Function<Query, BitSetProducer> bitSetProducer, IndexSettings indexSettings) {
220+
public Builder(
221+
String name,
222+
Function<Query, BitSetProducer> bitSetProducer,
223+
IndexSettings indexSettings,
224+
ModelRegistry modelRegistry
225+
) {
211226
super(name);
227+
this.modelRegistry = modelRegistry;
212228
this.useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(indexSettings.getSettings()) == false;
213229
this.inferenceFieldBuilder = c -> createInferenceField(
214230
c,
@@ -266,9 +282,32 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
266282
if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) {
267283
throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
268284
}
285+
286+
if (modelSettings.get() == null) {
287+
try {
288+
var resolvedModelSettings = modelRegistry.getMinimalServiceSettings(inferenceId.get());
289+
if (resolvedModelSettings != null) {
290+
modelSettings.setValue(resolvedModelSettings);
291+
}
292+
} catch (ResourceNotFoundException exc) {
293+
// We allow the inference ID to be unregistered at this point.
294+
// This will delay the creation of sub-fields, so indexing and querying for this field won't work
295+
// until the corresponding inference endpoint is created.
296+
}
297+
}
298+
269299
if (modelSettings.get() != null) {
270300
validateServiceSettings(modelSettings.get());
301+
} else {
302+
logger.warn(
303+
"The field [{}] references an unknown inference ID [{}]. "
304+
+ "Indexing and querying this field will not work correctly until the corresponding "
305+
+ "inference endpoint is created.",
306+
leafName(),
307+
inferenceId.get()
308+
);
271309
}
310+
272311
final String fullName = context.buildFullName(leafName());
273312

274313
if (context.isInNestedContext()) {
@@ -289,7 +328,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
289328
useLegacyFormat,
290329
meta.getValue()
291330
),
292-
builderParams(this, context)
331+
builderParams(this, context),
332+
modelRegistry
293333
);
294334
}
295335

@@ -330,9 +370,17 @@ private SemanticTextFieldMapper copySettings(SemanticTextFieldMapper mapper, Map
330370
}
331371
}
332372

333-
private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams builderParams) {
373+
private final ModelRegistry modelRegistry;
374+
375+
private SemanticTextFieldMapper(
376+
String simpleName,
377+
MappedFieldType mappedFieldType,
378+
BuilderParams builderParams,
379+
ModelRegistry modelRegistry
380+
) {
334381
super(simpleName, mappedFieldType, builderParams);
335382
ensureMultiFields(builderParams.multiFields().iterator());
383+
this.modelRegistry = modelRegistry;
336384
}
337385

338386
private void ensureMultiFields(Iterator<FieldMapper> mappers) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,18 +139,18 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap)
139139
private static final String MODEL_ID_FIELD = "model_id";
140140
private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
141141

142-
private final ClusterService clusterService;
143142
private final OriginSettingClient client;
144143
private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;
145144

146145
private final MasterServiceTaskQueue<MetadataTask> metadataTaskQueue;
147146
private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
148147
private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
149148

149+
private volatile Metadata lastMetadata;
150+
150151
public ModelRegistry(ClusterService clusterService, Client client) {
151152
this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
152153
this.defaultConfigIds = new ConcurrentHashMap<>();
153-
this.clusterService = clusterService;
154154
var executor = new SimpleBatchedAckListenerTaskExecutor<MetadataTask>() {
155155
@Override
156156
public Tuple<ClusterState, ClusterStateAckListener> executeTask(MetadataTask task, ClusterState clusterState) throws Exception {
@@ -222,11 +222,17 @@ public void clearDefaultIds() {
222222
* @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
223223
*/
224224
public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
225+
synchronized (this) {
226+
assert lastMetadata != null : "initial cluster state not set yet";
227+
if (lastMetadata == null) {
228+
throw new IllegalStateException("initial cluster state not set yet");
229+
}
230+
}
225231
var config = defaultConfigIds.get(inferenceEntityId);
226232
if (config != null) {
227233
return config.settings();
228234
}
229-
var state = ModelRegistryMetadata.fromState(clusterService.state().metadata());
235+
var state = ModelRegistryMetadata.fromState(lastMetadata);
230236
var existing = state.getMinimalServiceSettings(inferenceEntityId);
231237
if (state.isUpgraded() && existing == null) {
232238
throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
@@ -931,6 +937,13 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
931937

932938
@Override
933939
public void clusterChanged(ClusterChangedEvent event) {
940+
if (lastMetadata == null || event.metadataChanged()) {
941+
// keep track of the last applied cluster state
942+
synchronized (this) {
943+
lastMetadata = event.state().metadata();
944+
}
945+
}
946+
934947
if (event.localNodeMaster() == false) {
935948
return;
936949
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.lucene.search.join.QueryBitSetProducer;
2525
import org.apache.lucene.search.join.ScoreMode;
2626
import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
27+
import org.elasticsearch.cluster.ClusterChangedEvent;
2728
import org.elasticsearch.common.CheckedBiConsumer;
2829
import org.elasticsearch.common.CheckedBiFunction;
2930
import org.elasticsearch.common.Strings;
@@ -62,14 +63,20 @@
6263
import org.elasticsearch.search.LeafNestedDocuments;
6364
import org.elasticsearch.search.NestedDocuments;
6465
import org.elasticsearch.search.SearchHit;
66+
import org.elasticsearch.test.ClusterServiceUtils;
67+
import org.elasticsearch.test.client.NoOpClient;
68+
import org.elasticsearch.threadpool.TestThreadPool;
6569
import org.elasticsearch.xcontent.XContentBuilder;
6670
import org.elasticsearch.xcontent.XContentType;
6771
import org.elasticsearch.xcontent.json.JsonXContent;
6872
import org.elasticsearch.xpack.core.XPackClientPlugin;
6973
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
7074
import org.elasticsearch.xpack.inference.InferencePlugin;
7175
import org.elasticsearch.xpack.inference.model.TestModel;
76+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
77+
import org.junit.After;
7278
import org.junit.AssumptionViolatedException;
79+
import org.junit.Before;
7380

7481
import java.io.IOException;
7582
import java.util.Collection;
@@ -79,6 +86,7 @@
7986
import java.util.Map;
8087
import java.util.Set;
8188
import java.util.function.BiConsumer;
89+
import java.util.function.Supplier;
8290

8391
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
8492
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
@@ -100,18 +108,43 @@
100108
public class SemanticTextFieldMapperTests extends MapperTestCase {
101109
private final boolean useLegacyFormat;
102110

111+
private TestThreadPool threadPool;
112+
103113
public SemanticTextFieldMapperTests(boolean useLegacyFormat) {
104114
this.useLegacyFormat = useLegacyFormat;
105115
}
106116

117+
@Before
118+
private void startThreadPool() {
119+
threadPool = createThreadPool();
120+
}
121+
122+
@After
123+
private void stopThreadPool() {
124+
threadPool.close();
125+
}
126+
107127
@ParametersFactory
108128
public static Iterable<Object[]> parameters() throws Exception {
109129
return List.of(new Object[] { true }, new Object[] { false });
110130
}
111131

112132
@Override
113133
protected Collection<? extends Plugin> getPlugins() {
114-
return List.of(new InferencePlugin(Settings.EMPTY), new XPackClientPlugin());
134+
var clusterService = ClusterServiceUtils.createClusterService(threadPool);
135+
var modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
136+
modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
137+
@Override
138+
public boolean localNodeMaster() {
139+
return false;
140+
}
141+
});
142+
return List.of(new InferencePlugin(Settings.EMPTY) {
143+
@Override
144+
protected Supplier<ModelRegistry> getModelRegistry() {
145+
return () -> modelRegistry;
146+
}
147+
}, new XPackClientPlugin());
115148
}
116149

117150
private MapperService createMapperService(XContentBuilder mappings, boolean useLegacyFormat) throws IOException {

0 commit comments

Comments
 (0)