diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2ff8d96b77b4c..70dc845b40628 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -40,14 +41,14 @@ public class EmbeddingRequestChunker> { // Visible for testing record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { - public String chunkText() { + String chunkText() { return inputs.get(inputIndex).substring(chunk.start(), chunk.end()); } } public record BatchRequest(List requests) { - public List inputs() { - return requests.stream().map(Request::chunkText).collect(Collectors.toList()); + public Supplier> inputs() { + return () -> requests.stream().map(Request::chunkText).collect(Collectors.toList()); } } @@ -144,7 +145,7 @@ public List batchRequestsWithListeners(ActionListener { - private final BatchRequest request; + private BatchRequest request; DebatchingListener(BatchRequest request) { this.request = request; @@ -170,6 +171,7 @@ public void onResponse(InferenceServiceResults inferenceServiceResults) { oldEmbedding -> oldEmbedding == null ? newEmbedding : oldEmbedding.merge(newEmbedding) ); } + request = null; if (resultCount.incrementAndGet() == batchRequests.size()) { sendFinalResponse(); } @@ -197,6 +199,7 @@ public void onFailure(Exception e) { for (Request request : request.requests) { resultsErrors.set(request.inputIndex(), e); } + this.request = null; if (resultCount.incrementAndGet() == batchRequests.size()) { sendFinalResponse(); } @@ -208,6 +211,7 @@ private void sendFinalResponse() { for (int i = 0; i < resultEmbeddings.size(); i++) { if (resultsErrors.get(i) != null) { response.add(new ChunkedInferenceError(resultsErrors.get(i))); + resultsErrors.set(i, null); } else { response.add(mergeResultsWithInputs(i)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java index b43e5ab70e2f2..5e39b3fa6a321 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java @@ -33,7 +33,7 @@ public SingleInputSenderExecutableAction( @Override public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { - if (inferenceInputs.inputSize() > 1) { + if (inferenceInputs.isSingleInput() == false) { listener.onFailure( new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java index 58c952b9c556a..3bf227e1c2740 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java @@ -35,7 +35,8 @@ public List getInputs() { return this.input; } - public int inputSize() { - return input.size(); + @Override + public boolean isSingleInput() { + return input.size() == 1; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index d1d017f1c61c5..3f59386082d73 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -12,6 +12,7 @@ import java.util.List; import java.util.Objects; +import java.util.function.Supplier; public class EmbeddingsInput extends InferenceInputs { @@ -23,29 +24,38 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) { return (EmbeddingsInput) inferenceInputs; } - private final List input; - + private final Supplier> listSupplier; private final InputType inputType; public EmbeddingsInput(List input, @Nullable InputType inputType) { this(input, inputType, false); } + public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { + super(false); + this.listSupplier = Objects.requireNonNull(inputSupplier); + this.inputType = inputType; + } + public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { super(stream); - this.input = Objects.requireNonNull(input); + Objects.requireNonNull(input); + this.listSupplier = () -> input; this.inputType = inputType; } public List getInputs() { - return this.input; + return this.listSupplier.get(); } public InputType getInputType() { return this.inputType; } - public int inputSize() { - return input.size(); + @Override + public boolean isSingleInput() { + // We can't measure the size of the input list without executing + // the supplier. + return false; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index 816d6550f9b04..89572b43bfdcc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -34,5 +34,5 @@ public boolean stream() { return stream; } - public abstract int inputSize(); + public abstract boolean isSingleInput(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index d755ac982ac31..a2526a2a293eb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -61,7 +61,8 @@ public Integer getTopN() { return topN; } - public int inputSize() { - return chunks.size(); + @Override + public boolean isSingleInput() { + return chunks.size() == 1; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index f4f0511a4cc1b..b779e8f968624 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -49,7 +49,7 @@ public UnifiedCompletionRequest getRequest() { return request; } - public int inputSize() { - return request.messages().size(); + public boolean isSingleInput() { + return request.messages().size() == 1; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 06e3c05a3d783..569d42c3bd187 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -1113,7 +1113,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft var inferenceRequest = buildInferenceRequest( esModel.mlNodeDeploymentId(), EmptyConfigUpdate.INSTANCE, - batch.batch().inputs(), + batch.batch().inputs().get(), inputType, timeout ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 929ed7e18e503..01890e0d0a356 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -49,45 +49,45 @@ public void testWhitespaceInput_SentenceChunker() { var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1)) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" ")); + assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(" ")); } public void testBlankInput_WordChunker() { var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); + assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("")); } public void testBlankInput_SentenceChunker() { var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1)) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("")); + assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("")); } public void testInputThatDoesNotChunk_WordChunker() { var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); + assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA")); } public void testInputThatDoesNotChunk_SentenceChunker() { var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1)) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs(), hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA")); + assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA")); } public void testShortInputsAreSingleBatch() { String input = "one chunk"; var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs(), contains(input)); + assertThat(batches.get(0).batch().inputs().get(), contains(input)); } public void testMultipleShortInputsAreSingleBatch() { @@ -95,7 +95,7 @@ public void testMultipleShortInputsAreSingleBatch() { var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); EmbeddingRequestChunker.BatchRequest batch = batches.get(0).batch(); - assertEquals(batch.inputs(), inputs); + assertEquals(batch.inputs().get(), inputs); for (int i = 0; i < inputs.size(); i++) { var request = batch.requests().get(i); assertThat(request.chunkText(), equalTo(inputs.get(i))); @@ -115,20 +115,20 @@ public void testManyInputsMakeManyBatches() { var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(4)); - assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(3).batch().inputs(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(1).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(3).batch().inputs().get(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get(0)); - assertEquals("input 9", batches.get(0).batch().inputs().get(9)); + assertEquals("input 0", batches.get(0).batch().inputs().get().get(0)); + assertEquals("input 9", batches.get(0).batch().inputs().get().get(9)); assertThat( - batches.get(1).batch().inputs(), + batches.get(1).batch().inputs().get(), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get(0)); - assertEquals("input 29", batches.get(2).batch().inputs().get(9)); - assertThat(batches.get(3).batch().inputs(), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().inputs().get().get(0)); + assertEquals("input 29", batches.get(2).batch().inputs().get().get(9)); + assertThat(batches.get(3).batch().inputs().get(), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { @@ -151,20 +151,20 @@ public void testChunkingSettingsProvided() { var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings()) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(4)); - assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch)); - assertThat(batches.get(3).batch().inputs(), hasSize(1)); + assertThat(batches.get(0).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(1).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch)); + assertThat(batches.get(3).batch().inputs().get(), hasSize(1)); - assertEquals("input 0", batches.get(0).batch().inputs().get(0)); - assertEquals("input 9", batches.get(0).batch().inputs().get(9)); + assertEquals("input 0", batches.get(0).batch().inputs().get().get(0)); + assertEquals("input 9", batches.get(0).batch().inputs().get().get(9)); assertThat( - batches.get(1).batch().inputs(), + batches.get(1).batch().inputs().get(), contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19") ); - assertEquals("input 20", batches.get(2).batch().inputs().get(0)); - assertEquals("input 29", batches.get(2).batch().inputs().get(9)); - assertThat(batches.get(3).batch().inputs(), contains("input 30")); + assertEquals("input 20", batches.get(2).batch().inputs().get().get(0)); + assertEquals("input 29", batches.get(2).batch().inputs().get().get(9)); + assertThat(batches.get(3).batch().inputs().get(), contains("input 30")); List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { @@ -195,7 +195,7 @@ public void testLongInputChunkedOverMultipleBatches() { assertThat(batches, hasSize(2)); var batch = batches.get(0).batch(); - assertThat(batch.inputs(), hasSize(batchSize)); + assertThat(batch.inputs().get(), hasSize(batchSize)); assertThat(batch.requests(), hasSize(batchSize)); EmbeddingRequestChunker.Request request = batch.requests().get(0); @@ -212,7 +212,7 @@ public void testLongInputChunkedOverMultipleBatches() { } batch = batches.get(1).batch(); - assertThat(batch.inputs(), hasSize(4)); + assertThat(batch.inputs().get(), hasSize(4)); assertThat(batch.requests(), hasSize(4)); for (int requestIndex = 0; requestIndex < 2; requestIndex++) { @@ -254,9 +254,9 @@ public void testVeryLongInput_Sparse() { // there are 10002 inference requests, resulting in 2001 batches. assertThat(batches, hasSize(2001)); for (int i = 0; i < 2000; i++) { - assertThat(batches.get(i).batch().inputs(), hasSize(5)); + assertThat(batches.get(i).batch().inputs().get(), hasSize(5)); } - assertThat(batches.get(2000).batch().inputs(), hasSize(2)); + assertThat(batches.get(2000).batch().inputs().get(), hasSize(2)); // Produce inference results for each request, with just the token // "word" and increasing weights. @@ -339,9 +339,9 @@ public void testVeryLongInput_Float() { // there are 10002 inference requests, resulting in 2001 batches. assertThat(batches, hasSize(2001)); for (int i = 0; i < 2000; i++) { - assertThat(batches.get(i).batch().inputs(), hasSize(5)); + assertThat(batches.get(i).batch().inputs().get(), hasSize(5)); } - assertThat(batches.get(2000).batch().inputs(), hasSize(2)); + assertThat(batches.get(2000).batch().inputs().get(), hasSize(2)); // Produce inference results for each request, with increasing weights. float weight = 0f; @@ -423,9 +423,9 @@ public void testVeryLongInput_Byte() { // there are 10002 inference requests, resulting in 2001 batches. assertThat(batches, hasSize(2001)); for (int i = 0; i < 2000; i++) { - assertThat(batches.get(i).batch().inputs(), hasSize(5)); + assertThat(batches.get(i).batch().inputs().get(), hasSize(5)); } - assertThat(batches.get(2000).batch().inputs(), hasSize(2)); + assertThat(batches.get(2000).batch().inputs().get(), hasSize(2)); // Produce inference results for each request, with increasing weights. byte weight = 0; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index 2061174813041..6f335ab32f01c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.junit.Before; import java.util.List; @@ -53,7 +54,7 @@ public void testOneInputIsValid() { var testRan = new AtomicBoolean(false); executableAction.execute( - mock(EmbeddingsInput.class), + new UnifiedChatInput(List.of("one"), "system", false), mock(TimeValue.class), ActionListener.wrap(success -> testRan.set(true), e -> fail(e, "Test failed.")) ); @@ -65,7 +66,7 @@ public void testMoreThanOneInput() { var badInput = mock(EmbeddingsInput.class); var input = List.of("one", "two"); when(badInput.getInputs()).thenReturn(input); - when(badInput.inputSize()).thenReturn(input.size()); + when(badInput.isSingleInput()).thenReturn(false); var actualException = new AtomicReference(); executableAction.execute( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java index 1c0643739d410..d21dc29765087 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java @@ -18,7 +18,7 @@ public class UnifiedChatInputTests extends ESTestCase { public void testConvertsStringInputToMessages() { var a = new UnifiedChatInput(List.of("hello", "awesome"), "a role", true); - assertThat(a.inputSize(), Matchers.is(2)); + assertThat(a.isSingleInput(), Matchers.is(false)); assertThat( a.getRequest(), Matchers.is(