17
17
import org .elasticsearch .core .Nullable ;
18
18
import org .elasticsearch .core .TimeValue ;
19
19
import org .elasticsearch .inference .ChunkedInference ;
20
+ import org .elasticsearch .inference .ChunkingSettings ;
20
21
import org .elasticsearch .inference .EmptySecretSettings ;
21
22
import org .elasticsearch .inference .EmptyTaskSettings ;
22
23
import org .elasticsearch .inference .InferenceServiceConfiguration ;
36
37
import org .elasticsearch .xpack .core .inference .results .ChunkedInferenceError ;
37
38
import org .elasticsearch .xpack .core .inference .results .SparseEmbeddingResults ;
38
39
import org .elasticsearch .xpack .core .ml .inference .results .ErrorInferenceResults ;
40
+ import org .elasticsearch .xpack .inference .chunking .ChunkingSettingsBuilder ;
41
+ import org .elasticsearch .xpack .inference .chunking .EmbeddingRequestChunker ;
39
42
import org .elasticsearch .xpack .inference .external .action .SenderExecutableAction ;
40
43
import org .elasticsearch .xpack .inference .external .http .sender .EmbeddingsInput ;
41
44
import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSender ;
68
71
import static org .elasticsearch .xpack .inference .services .ServiceFields .MODEL_ID ;
69
72
import static org .elasticsearch .xpack .inference .services .ServiceUtils .createInvalidModelException ;
70
73
import static org .elasticsearch .xpack .inference .services .ServiceUtils .parsePersistedConfigErrorMsg ;
74
+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMap ;
71
75
import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMapOrDefaultEmpty ;
72
76
import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeFromMapOrThrowIfNull ;
73
77
import static org .elasticsearch .xpack .inference .services .ServiceUtils .throwIfNotEmptyMap ;
@@ -77,6 +81,7 @@ public class ElasticInferenceService extends SenderService {
77
81
78
82
public static final String NAME = "elastic" ;
79
83
public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service" ;
84
+ public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512 ;
80
85
81
86
private static final EnumSet <TaskType > IMPLEMENTED_TASK_TYPES = EnumSet .of (
82
87
TaskType .SPARSE_EMBEDDING ,
@@ -154,7 +159,8 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
154
159
new ElasticInferenceServiceSparseEmbeddingsServiceSettings (DEFAULT_ELSER_MODEL_ID_V2 , null , null ),
155
160
EmptyTaskSettings .INSTANCE ,
156
161
EmptySecretSettings .INSTANCE ,
157
- elasticInferenceServiceComponents
162
+ elasticInferenceServiceComponents ,
163
+ ChunkingSettingsBuilder .DEFAULT_SETTINGS
158
164
),
159
165
MinimalServiceSettings .sparseEmbedding (NAME )
160
166
)
@@ -284,12 +290,25 @@ protected void doChunkedInfer(
284
290
TimeValue timeout ,
285
291
ActionListener <List <ChunkedInference >> listener
286
292
) {
287
- // Pass-through without actually performing chunking (result will have a single chunk per input)
288
- ActionListener <InferenceServiceResults > inferListener = listener .delegateFailureAndWrap (
289
- (delegate , response ) -> delegate .onResponse (translateToChunkedResults (inputs , response ))
290
- );
293
+ if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel ) {
294
+ var actionCreator = new ElasticInferenceServiceActionCreator (getSender (), getServiceComponents (), getCurrentTraceInfo ());
295
+
296
+ List <EmbeddingRequestChunker .BatchRequestAndListener > batchedRequests = new EmbeddingRequestChunker <>(
297
+ inputs .getInputs (),
298
+ SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE ,
299
+ model .getConfigurations ().getChunkingSettings ()
300
+ ).batchRequestsWithListeners (listener );
301
+
302
+ for (var request : batchedRequests ) {
303
+ var action = sparseTextEmbeddingsModel .accept (actionCreator , taskSettings );
304
+ action .execute (EmbeddingsInput .fromStrings (request .batch ().inputs ().get (), inputType ), timeout , request .listener ());
305
+ }
306
+
307
+ return ;
308
+ }
291
309
292
- doInfer (model , inputs , taskSettings , timeout , inferListener );
310
+ // Only ElasticInferenceServiceSparseEmbeddingsModel is handled above; fail the request for any other model type
311
+ listener .onFailure (createInvalidModelException (model ));
293
312
}
294
313
295
314
@ Override
@@ -308,6 +327,13 @@ public void parseRequestConfig(
308
327
Map <String , Object > serviceSettingsMap = removeFromMapOrThrowIfNull (config , ModelConfigurations .SERVICE_SETTINGS );
309
328
Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
310
329
330
+ ChunkingSettings chunkingSettings = null ;
331
+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
332
+ chunkingSettings = ChunkingSettingsBuilder .fromMap (
333
+ removeFromMapOrDefaultEmpty (config , ModelConfigurations .CHUNKING_SETTINGS )
334
+ );
335
+ }
336
+
311
337
ElasticInferenceServiceModel model = createModel (
312
338
inferenceEntityId ,
313
339
taskType ,
@@ -316,7 +342,8 @@ public void parseRequestConfig(
316
342
serviceSettingsMap ,
317
343
elasticInferenceServiceComponents ,
318
344
TaskType .unsupportedTaskTypeErrorMsg (taskType , NAME ),
319
- ConfigurationParseContext .REQUEST
345
+ ConfigurationParseContext .REQUEST ,
346
+ chunkingSettings
320
347
);
321
348
322
349
throwIfNotEmptyMap (config , NAME );
@@ -352,7 +379,8 @@ private static ElasticInferenceServiceModel createModel(
352
379
@ Nullable Map <String , Object > secretSettings ,
353
380
ElasticInferenceServiceComponents elasticInferenceServiceComponents ,
354
381
String failureMessage ,
355
- ConfigurationParseContext context
382
+ ConfigurationParseContext context ,
383
+ ChunkingSettings chunkingSettings
356
384
) {
357
385
return switch (taskType ) {
358
386
case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel (
@@ -363,7 +391,8 @@ private static ElasticInferenceServiceModel createModel(
363
391
taskSettings ,
364
392
secretSettings ,
365
393
elasticInferenceServiceComponents ,
366
- context
394
+ context ,
395
+ chunkingSettings
367
396
);
368
397
case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel (
369
398
inferenceEntityId ,
@@ -400,13 +429,19 @@ public Model parsePersistedConfigWithSecrets(
400
429
Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
401
430
Map <String , Object > secretSettingsMap = removeFromMapOrDefaultEmpty (secrets , ModelSecrets .SECRET_SETTINGS );
402
431
432
+ ChunkingSettings chunkingSettings = null ;
433
+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
434
+ chunkingSettings = ChunkingSettingsBuilder .fromMap (removeFromMap (config , ModelConfigurations .CHUNKING_SETTINGS ));
435
+ }
436
+
403
437
return createModelFromPersistent (
404
438
inferenceEntityId ,
405
439
taskType ,
406
440
serviceSettingsMap ,
407
441
taskSettingsMap ,
408
442
secretSettingsMap ,
409
- parsePersistedConfigErrorMsg (inferenceEntityId , NAME )
443
+ parsePersistedConfigErrorMsg (inferenceEntityId , NAME ),
444
+ chunkingSettings
410
445
);
411
446
}
412
447
@@ -415,13 +450,19 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
415
450
Map <String , Object > serviceSettingsMap = removeFromMapOrThrowIfNull (config , ModelConfigurations .SERVICE_SETTINGS );
416
451
Map <String , Object > taskSettingsMap = removeFromMapOrDefaultEmpty (config , ModelConfigurations .TASK_SETTINGS );
417
452
453
+ ChunkingSettings chunkingSettings = null ;
454
+ if (TaskType .SPARSE_EMBEDDING .equals (taskType )) {
455
+ chunkingSettings = ChunkingSettingsBuilder .fromMap (removeFromMap (config , ModelConfigurations .CHUNKING_SETTINGS ));
456
+ }
457
+
418
458
return createModelFromPersistent (
419
459
inferenceEntityId ,
420
460
taskType ,
421
461
serviceSettingsMap ,
422
462
taskSettingsMap ,
423
463
null ,
424
- parsePersistedConfigErrorMsg (inferenceEntityId , NAME )
464
+ parsePersistedConfigErrorMsg (inferenceEntityId , NAME ),
465
+ chunkingSettings
425
466
);
426
467
}
427
468
@@ -436,7 +477,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
436
477
Map <String , Object > serviceSettings ,
437
478
Map <String , Object > taskSettings ,
438
479
@ Nullable Map <String , Object > secretSettings ,
439
- String failureMessage
480
+ String failureMessage ,
481
+ ChunkingSettings chunkingSettings
440
482
) {
441
483
return createModel (
442
484
inferenceEntityId ,
@@ -446,7 +488,8 @@ private ElasticInferenceServiceModel createModelFromPersistent(
446
488
secretSettings ,
447
489
elasticInferenceServiceComponents ,
448
490
failureMessage ,
449
- ConfigurationParseContext .PERSISTENT
491
+ ConfigurationParseContext .PERSISTENT ,
492
+ chunkingSettings
450
493
);
451
494
}
452
495
0 commit comments