Skip to content

[8.x] [ML] Delay copying chunked input strings (#125837) #126402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
Expand All @@ -40,14 +41,14 @@ public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {

// Visible for testing
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<String> inputs) {
public String chunkText() {
String chunkText() {
return inputs.get(inputIndex).substring(chunk.start(), chunk.end());
}
}

public record BatchRequest(List<Request> requests) {
public List<String> inputs() {
return requests.stream().map(Request::chunkText).collect(Collectors.toList());
public Supplier<List<String>> inputs() {
return () -> requests.stream().map(Request::chunkText).collect(Collectors.toList());
}
}

Expand Down Expand Up @@ -144,7 +145,7 @@ public List<BatchRequestAndListener> batchRequestsWithListeners(ActionListener<L
*/
private class DebatchingListener implements ActionListener<InferenceServiceResults> {

private final BatchRequest request;
private BatchRequest request;

DebatchingListener(BatchRequest request) {
this.request = request;
Expand All @@ -170,6 +171,7 @@ public void onResponse(InferenceServiceResults inferenceServiceResults) {
oldEmbedding -> oldEmbedding == null ? newEmbedding : oldEmbedding.merge(newEmbedding)
);
}
request = null;
if (resultCount.incrementAndGet() == batchRequests.size()) {
sendFinalResponse();
}
Expand Down Expand Up @@ -197,6 +199,7 @@ public void onFailure(Exception e) {
for (Request request : request.requests) {
resultsErrors.set(request.inputIndex(), e);
}
this.request = null;
if (resultCount.incrementAndGet() == batchRequests.size()) {
sendFinalResponse();
}
Expand All @@ -208,6 +211,7 @@ private void sendFinalResponse() {
for (int i = 0; i < resultEmbeddings.size(); i++) {
if (resultsErrors.get(i) != null) {
response.add(new ChunkedInferenceError(resultsErrors.get(i)));
resultsErrors.set(i, null);
} else {
response.add(mergeResultsWithInputs(i));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public SingleInputSenderExecutableAction(

@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
if (inferenceInputs.inputSize() > 1) {
if (inferenceInputs.isSingleInput() == false) {
listener.onFailure(
new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ public List<String> getInputs() {
return this.input;
}

public int inputSize() {
return input.size();
@Override
public boolean isSingleInput() {
return input.size() == 1;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

public class EmbeddingsInput extends InferenceInputs {

Expand All @@ -23,29 +24,38 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
return (EmbeddingsInput) inferenceInputs;
}

private final List<String> input;

private final Supplier<List<String>> listSupplier;
private final InputType inputType;

public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
this(input, inputType, false);
}

public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
super(false);
this.listSupplier = Objects.requireNonNull(inputSupplier);
this.inputType = inputType;
}

public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
super(stream);
this.input = Objects.requireNonNull(input);
Objects.requireNonNull(input);
this.listSupplier = () -> input;
this.inputType = inputType;
}

public List<String> getInputs() {
return this.input;
return this.listSupplier.get();
}

public InputType getInputType() {
return this.inputType;
}

public int inputSize() {
return input.size();
@Override
public boolean isSingleInput() {
// We can't measure the size of the input list without executing
// the supplier.
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ public boolean stream() {
return stream;
}

public abstract int inputSize();
public abstract boolean isSingleInput();
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public Integer getTopN() {
return topN;
}

public int inputSize() {
return chunks.size();
@Override
public boolean isSingleInput() {
return chunks.size() == 1;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public UnifiedCompletionRequest getRequest() {
return request;
}

public int inputSize() {
return request.messages().size();
public boolean isSingleInput() {
return request.messages().size() == 1;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft
var inferenceRequest = buildInferenceRequest(
esModel.mlNodeDeploymentId(),
EmptyConfigUpdate.INSTANCE,
batch.batch().inputs(),
batch.batch().inputs().get(),
inputType,
timeout
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,53 +49,53 @@ public void testWhitespaceInput_SentenceChunker() {
var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" "));
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(" "));
}

public void testBlankInput_WordChunker() {
var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(""));
}

public void testBlankInput_SentenceChunker() {
var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(""));
}

public void testInputThatDoesNotChunk_WordChunker() {
var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA"));
}

public void testInputThatDoesNotChunk_SentenceChunker() {
var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA"));
}

public void testShortInputsAreSingleBatch() {
String input = "one chunk";
var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), contains(input));
assertThat(batches.get(0).batch().inputs().get(), contains(input));
}

public void testMultipleShortInputsAreSingleBatch() {
List<String> inputs = List.of("1st small", "2nd small", "3rd small");
var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
EmbeddingRequestChunker.BatchRequest batch = batches.get(0).batch();
assertEquals(batch.inputs(), inputs);
assertEquals(batch.inputs().get(), inputs);
for (int i = 0; i < inputs.size(); i++) {
var request = batch.requests().get(i);
assertThat(request.chunkText(), equalTo(inputs.get(i)));
Expand All @@ -115,20 +115,20 @@ public void testManyInputsMakeManyBatches() {

var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(4));
assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(3).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(1).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(3).batch().inputs().get(), hasSize(1));

assertEquals("input 0", batches.get(0).batch().inputs().get(0));
assertEquals("input 9", batches.get(0).batch().inputs().get(9));
assertEquals("input 0", batches.get(0).batch().inputs().get().get(0));
assertEquals("input 9", batches.get(0).batch().inputs().get().get(9));
assertThat(
batches.get(1).batch().inputs(),
batches.get(1).batch().inputs().get(),
contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19")
);
assertEquals("input 20", batches.get(2).batch().inputs().get(0));
assertEquals("input 29", batches.get(2).batch().inputs().get(9));
assertThat(batches.get(3).batch().inputs(), contains("input 30"));
assertEquals("input 20", batches.get(2).batch().inputs().get().get(0));
assertEquals("input 29", batches.get(2).batch().inputs().get().get(9));
assertThat(batches.get(3).batch().inputs().get(), contains("input 30"));

List<EmbeddingRequestChunker.Request> requests = batches.get(0).batch().requests();
for (int i = 0; i < requests.size(); i++) {
Expand All @@ -151,20 +151,20 @@ public void testChunkingSettingsProvided() {
var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings())
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(4));
assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(3).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(1).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
assertThat(batches.get(3).batch().inputs().get(), hasSize(1));

assertEquals("input 0", batches.get(0).batch().inputs().get(0));
assertEquals("input 9", batches.get(0).batch().inputs().get(9));
assertEquals("input 0", batches.get(0).batch().inputs().get().get(0));
assertEquals("input 9", batches.get(0).batch().inputs().get().get(9));
assertThat(
batches.get(1).batch().inputs(),
batches.get(1).batch().inputs().get(),
contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19")
);
assertEquals("input 20", batches.get(2).batch().inputs().get(0));
assertEquals("input 29", batches.get(2).batch().inputs().get(9));
assertThat(batches.get(3).batch().inputs(), contains("input 30"));
assertEquals("input 20", batches.get(2).batch().inputs().get().get(0));
assertEquals("input 29", batches.get(2).batch().inputs().get().get(9));
assertThat(batches.get(3).batch().inputs().get(), contains("input 30"));

List<EmbeddingRequestChunker.Request> requests = batches.get(0).batch().requests();
for (int i = 0; i < requests.size(); i++) {
Expand Down Expand Up @@ -195,7 +195,7 @@ public void testLongInputChunkedOverMultipleBatches() {
assertThat(batches, hasSize(2));

var batch = batches.get(0).batch();
assertThat(batch.inputs(), hasSize(batchSize));
assertThat(batch.inputs().get(), hasSize(batchSize));
assertThat(batch.requests(), hasSize(batchSize));

EmbeddingRequestChunker.Request request = batch.requests().get(0);
Expand All @@ -212,7 +212,7 @@ public void testLongInputChunkedOverMultipleBatches() {
}

batch = batches.get(1).batch();
assertThat(batch.inputs(), hasSize(4));
assertThat(batch.inputs().get(), hasSize(4));
assertThat(batch.requests(), hasSize(4));

for (int requestIndex = 0; requestIndex < 2; requestIndex++) {
Expand Down Expand Up @@ -254,9 +254,9 @@ public void testVeryLongInput_Sparse() {
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
assertThat(batches.get(i).batch().inputs().get(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
assertThat(batches.get(2000).batch().inputs().get(), hasSize(2));

// Produce inference results for each request, with just the token
// "word" and increasing weights.
Expand Down Expand Up @@ -339,9 +339,9 @@ public void testVeryLongInput_Float() {
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
assertThat(batches.get(i).batch().inputs().get(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
assertThat(batches.get(2000).batch().inputs().get(), hasSize(2));

// Produce inference results for each request, with increasing weights.
float weight = 0f;
Expand Down Expand Up @@ -423,9 +423,9 @@ public void testVeryLongInput_Byte() {
// there are 10002 inference requests, resulting in 2001 batches.
assertThat(batches, hasSize(2001));
for (int i = 0; i < 2000; i++) {
assertThat(batches.get(i).batch().inputs(), hasSize(5));
assertThat(batches.get(i).batch().inputs().get(), hasSize(5));
}
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
assertThat(batches.get(2000).batch().inputs().get(), hasSize(2));

// Produce inference results for each request, with increasing weights.
byte weight = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.junit.Before;

import java.util.List;
Expand Down Expand Up @@ -53,7 +54,7 @@ public void testOneInputIsValid() {
var testRan = new AtomicBoolean(false);

executableAction.execute(
mock(EmbeddingsInput.class),
new UnifiedChatInput(List.of("one"), "system", false),
mock(TimeValue.class),
ActionListener.wrap(success -> testRan.set(true), e -> fail(e, "Test failed."))
);
Expand All @@ -65,7 +66,7 @@ public void testMoreThanOneInput() {
var badInput = mock(EmbeddingsInput.class);
var input = List.of("one", "two");
when(badInput.getInputs()).thenReturn(input);
when(badInput.inputSize()).thenReturn(input.size());
when(badInput.isSingleInput()).thenReturn(false);
var actualException = new AtomicReference<Exception>();

executableAction.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class UnifiedChatInputTests extends ESTestCase {
public void testConvertsStringInputToMessages() {
var a = new UnifiedChatInput(List.of("hello", "awesome"), "a role", true);

assertThat(a.inputSize(), Matchers.is(2));
assertThat(a.isSingleInput(), Matchers.is(false));
assertThat(
a.getRequest(),
Matchers.is(
Expand Down