Skip to content

Commit 2a57e4e

Browse files
authored
[8.19] [EIS] Implement chunked & batched inference for sparse text embeddings (#129922) (#129958)
1 parent 82800f9 commit 2a57e4e

File tree

5 files changed: 71 additions and 20 deletions.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 56 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.core.Nullable;
1818
import org.elasticsearch.core.TimeValue;
1919
import org.elasticsearch.inference.ChunkedInference;
20+
import org.elasticsearch.inference.ChunkingSettings;
2021
import org.elasticsearch.inference.EmptySecretSettings;
2122
import org.elasticsearch.inference.EmptyTaskSettings;
2223
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -36,6 +37,8 @@
3637
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
3738
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3839
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
40+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
41+
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
3942
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
4043
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
4144
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -68,6 +71,7 @@
6871
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
6972
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
7073
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
74+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
7175
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
7276
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
7377
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -77,6 +81,7 @@ public class ElasticInferenceService extends SenderService {
7781

7882
public static final String NAME = "elastic";
7983
public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
84+
public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512;
8085

8186
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
8287
TaskType.SPARSE_EMBEDDING,
@@ -154,7 +159,8 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
154159
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
155160
EmptyTaskSettings.INSTANCE,
156161
EmptySecretSettings.INSTANCE,
157-
elasticInferenceServiceComponents
162+
elasticInferenceServiceComponents,
163+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
158164
),
159165
MinimalServiceSettings.sparseEmbedding(NAME)
160166
)
@@ -284,12 +290,25 @@ protected void doChunkedInfer(
284290
TimeValue timeout,
285291
ActionListener<List<ChunkedInference>> listener
286292
) {
287-
// Pass-through without actually performing chunking (result will have a single chunk per input)
288-
ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
289-
(delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response))
290-
);
293+
if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) {
294+
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
295+
296+
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
297+
inputs.getInputs(),
298+
SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE,
299+
model.getConfigurations().getChunkingSettings()
300+
).batchRequestsWithListeners(listener);
301+
302+
for (var request : batchedRequests) {
303+
var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings);
304+
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
305+
}
306+
307+
return;
308+
}
291309

292-
doInfer(model, inputs, taskSettings, timeout, inferListener);
310+
// Model cannot perform chunked inference
311+
listener.onFailure(createInvalidModelException(model));
293312
}
294313

295314
@Override
@@ -308,6 +327,13 @@ public void parseRequestConfig(
308327
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
309328
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
310329

330+
ChunkingSettings chunkingSettings = null;
331+
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
332+
chunkingSettings = ChunkingSettingsBuilder.fromMap(
333+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
334+
);
335+
}
336+
311337
ElasticInferenceServiceModel model = createModel(
312338
inferenceEntityId,
313339
taskType,
@@ -316,7 +342,8 @@ public void parseRequestConfig(
316342
serviceSettingsMap,
317343
elasticInferenceServiceComponents,
318344
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
319-
ConfigurationParseContext.REQUEST
345+
ConfigurationParseContext.REQUEST,
346+
chunkingSettings
320347
);
321348

322349
throwIfNotEmptyMap(config, NAME);
@@ -352,7 +379,8 @@ private static ElasticInferenceServiceModel createModel(
352379
@Nullable Map<String, Object> secretSettings,
353380
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
354381
String failureMessage,
355-
ConfigurationParseContext context
382+
ConfigurationParseContext context,
383+
ChunkingSettings chunkingSettings
356384
) {
357385
return switch (taskType) {
358386
case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel(
@@ -363,7 +391,8 @@ private static ElasticInferenceServiceModel createModel(
363391
taskSettings,
364392
secretSettings,
365393
elasticInferenceServiceComponents,
366-
context
394+
context,
395+
chunkingSettings
367396
);
368397
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
369398
inferenceEntityId,
@@ -400,13 +429,19 @@ public Model parsePersistedConfigWithSecrets(
400429
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
401430
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
402431

432+
ChunkingSettings chunkingSettings = null;
433+
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
434+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
435+
}
436+
403437
return createModelFromPersistent(
404438
inferenceEntityId,
405439
taskType,
406440
serviceSettingsMap,
407441
taskSettingsMap,
408442
secretSettingsMap,
409-
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
443+
parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
444+
chunkingSettings
410445
);
411446
}
412447

@@ -415,13 +450,19 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
415450
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
416451
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
417452

453+
ChunkingSettings chunkingSettings = null;
454+
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
455+
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
456+
}
457+
418458
return createModelFromPersistent(
419459
inferenceEntityId,
420460
taskType,
421461
serviceSettingsMap,
422462
taskSettingsMap,
423463
null,
424-
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
464+
parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
465+
chunkingSettings
425466
);
426467
}
427468

@@ -436,7 +477,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
436477
Map<String, Object> serviceSettings,
437478
Map<String, Object> taskSettings,
438479
@Nullable Map<String, Object> secretSettings,
439-
String failureMessage
480+
String failureMessage,
481+
ChunkingSettings chunkingSettings
440482
) {
441483
return createModel(
442484
inferenceEntityId,
@@ -446,7 +488,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
446488
secretSettings,
447489
elasticInferenceServiceComponents,
448490
failureMessage,
449-
ConfigurationParseContext.PERSISTENT
491+
ConfigurationParseContext.PERSISTENT,
492+
chunkingSettings
450493
);
451494
}
452495

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.ChunkingSettings;
1213
import org.elasticsearch.inference.EmptySecretSettings;
1314
import org.elasticsearch.inference.EmptyTaskSettings;
1415
import org.elasticsearch.inference.ModelConfigurations;
@@ -37,7 +38,8 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
3738
Map<String, Object> taskSettings,
3839
Map<String, Object> secrets,
3940
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
40-
ConfigurationParseContext context
41+
ConfigurationParseContext context,
42+
ChunkingSettings chunkingSettings
4143
) {
4244
this(
4345
inferenceEntityId,
@@ -46,7 +48,8 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
4648
ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(serviceSettings, context),
4749
EmptyTaskSettings.INSTANCE,
4850
EmptySecretSettings.INSTANCE,
49-
elasticInferenceServiceComponents
51+
elasticInferenceServiceComponents,
52+
chunkingSettings
5053
);
5154
}
5255

@@ -65,10 +68,11 @@ public ElasticInferenceServiceSparseEmbeddingsModel(
6568
ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings,
6669
@Nullable TaskSettings taskSettings,
6770
@Nullable SecretSettings secretSettings,
68-
ElasticInferenceServiceComponents elasticInferenceServiceComponents
71+
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
72+
ChunkingSettings chunkingSettings
6973
) {
7074
super(
71-
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
75+
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
7276
new ModelSecrets(secretSettings),
7377
serviceSettings,
7478
elasticInferenceServiceComponents

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.inference.EmptyTaskSettings;
1212
import org.elasticsearch.inference.TaskType;
1313
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
1415

1516
public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCase {
1617

@@ -26,7 +27,8 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur
2627
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
2728
EmptyTaskSettings.INSTANCE,
2829
EmptySecretSettings.INSTANCE,
29-
ElasticInferenceServiceComponents.of(url)
30+
ElasticInferenceServiceComponents.of(url),
31+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
3032
);
3133
}
3234
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -834,7 +834,7 @@ public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws I
834834
}
835835
}
836836

837-
public void testChunkedInfer_PassesThrough() throws IOException {
837+
public void testChunkedInfer() throws IOException {
838838
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
839839
var elasticInferenceServiceURL = getUrl(webServer);
840840

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.test.ESSingleNodeTestCase;
2222
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
2323
import org.elasticsearch.xpack.inference.Utils;
24+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
2425
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2526
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2627
import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig;
@@ -196,7 +197,8 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints() {
196197
new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-v2", null, null),
197198
EmptyTaskSettings.INSTANCE,
198199
EmptySecretSettings.INSTANCE,
199-
ElasticInferenceServiceComponents.EMPTY_INSTANCE
200+
ElasticInferenceServiceComponents.EMPTY_INSTANCE,
201+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
200202
),
201203
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME)
202204
)

0 commit comments

Comments
 (0)