
[8.x] Add max.chunks to EmbeddingRequestChunker to prevent OOM (#123150) #126383


Merged · 3 commits · Apr 7, 2025
docs/changelog/123150.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+pr: 123150
+summary: Limit the number of chunks for semantic text to prevent high memory usage
+area: Machine Learning
+type: feature
+issues: []
ChunkedInferenceEmbedding.java
@@ -17,7 +17,7 @@

import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;

-public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> chunks) implements ChunkedInference {
+public record ChunkedInferenceEmbedding(List<EmbeddingResults.Chunk> chunks) implements ChunkedInference {

public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());
@@ -27,10 +27,7 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
results.add(
new ChunkedInferenceEmbedding(
List.of(
-new SparseEmbeddingResults.Chunk(
-sparseEmbeddingResults.embeddings().get(i).tokens(),
-new TextOffset(0, inputs.get(i).length())
-)
+new EmbeddingResults.Chunk(sparseEmbeddingResults.embeddings().get(i), new TextOffset(0, inputs.get(i).length()))
)
)
);
@@ -41,10 +38,10 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {

@Override
public Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException {
-var asChunk = new ArrayList<Chunk>();
-for (var chunk : chunks()) {
-asChunk.add(chunk.toChunk(xcontent));
+List<Chunk> chunkedInferenceChunks = new ArrayList<>();
+for (EmbeddingResults.Chunk embeddingResultsChunk : chunks()) {
+chunkedInferenceChunks.add(new Chunk(embeddingResultsChunk.offset(), embeddingResultsChunk.embedding().toBytesRef(xcontent)));
}
-return asChunk.iterator();
+return chunkedInferenceChunks.iterator();
}
}
EmbeddingResults.java
@@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.inference.results;

+import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContent;
@@ -19,31 +20,30 @@
* A call to the inference service may contain multiple input texts, so these results may
* contain multiple results.
*/
-public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C>>
-extends
-InferenceServiceResults {
-
-/**
-* A resulting embedding together with the offset into the input text.
-*/
-interface Chunk {
-ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;
-
-ChunkedInference.TextOffset offset();
-}
+public interface EmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends InferenceServiceResults {

/**
* A resulting embedding for one of the input texts to the inference service.
*/
-interface Embedding<C extends Chunk> {
+interface Embedding<E extends Embedding<E>> {
/**
-* Combines the resulting embedding with the offset into the input text into a chunk.
+* Merges the existing embedding and provided embedding into a new embedding.
*/
-C toChunk(ChunkedInference.TextOffset offset);
+E merge(E embedding);
+
+/**
+* Serializes the embedding to bytes.
+*/
+BytesReference toBytesRef(XContent xContent) throws IOException;
}

/**
* The resulting list of embeddings for the input texts to the inference service.
*/
List<E> embeddings();

+/**
+* A resulting embedding together with the offset into the input text.
+*/
+record Chunk(Embedding<?> embedding, ChunkedInference.TextOffset offset) {}
}
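
The merge operation is the heart of this change: judging from the PR title, EmbeddingRequestChunker (not shown in this diff) uses it to cap how many chunk embeddings are kept in memory. The sketch below is a hypothetical, standalone illustration of that idea, reducing a list of per-chunk embeddings to at most maxChunks entries by merging overflow into earlier buckets; every name in it is a stand-in, not the Elasticsearch API.

import java.util.ArrayList;
import java.util.List;

// Stand-in for EmbeddingResults.Embedding<E>; illustration only.
interface MergeableEmbedding<E extends MergeableEmbedding<E>> {
    E merge(E other);
}

class ChunkLimiterSketch {
    // Cap a list of per-chunk embeddings at maxChunks entries by assigning each
    // input chunk to a bucket and merging chunks that share a bucket. This is a
    // guess at the general shape; the real EmbeddingRequestChunker is not in this diff.
    static <E extends MergeableEmbedding<E>> List<E> limit(List<E> embeddings, int maxChunks) {
        List<E> merged = new ArrayList<>();
        for (int i = 0; i < embeddings.size(); i++) {
            int bucket = (int) ((long) i * maxChunks / embeddings.size());
            if (bucket < merged.size()) {
                merged.set(bucket, merged.get(bucket).merge(embeddings.get(i)));
            } else {
                merged.add(embeddings.get(i));
            }
        }
        return merged;
    }
}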
SparseEmbeddingResults.java
@@ -14,7 +14,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
-import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
@@ -27,17 +26,17 @@

import java.io.IOException;
import java.util.ArrayList;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;

-public record SparseEmbeddingResults(List<Embedding> embeddings)
-implements
-EmbeddingResults<SparseEmbeddingResults.Chunk, SparseEmbeddingResults.Embedding> {
+public record SparseEmbeddingResults(List<Embedding> embeddings) implements EmbeddingResults<SparseEmbeddingResults.Embedding> {

public static final String NAME = "sparse_embedding_results";
public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString();
@@ -124,7 +123,7 @@ public record Embedding(List<WeightedToken> tokens, boolean isTruncated)
implements
Writeable,
ToXContentObject,
-EmbeddingResults.Embedding<Chunk> {
+EmbeddingResults.Embedding<Embedding> {

public static final String EMBEDDING = "embedding";
public static final String IS_TRUNCATED = "is_truncated";
@@ -175,18 +174,35 @@ public String toString() {
}

@Override
-public Chunk toChunk(ChunkedInference.TextOffset offset) {
-return new Chunk(tokens, offset);
-}
-}
-
-public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
+public Embedding merge(Embedding embedding) {
+// This code assumes that the tokens are sorted by weight in descending order.
+// If that's not the case, the resulting merged embedding will be incorrect.
+List<WeightedToken> mergedTokens = new ArrayList<>();
+Set<String> seenTokens = new HashSet<>();
+int i = 0;
+int j = 0;
+// TODO: maybe truncate tokens here when it's getting too large?
+while (i < tokens().size() || j < embedding.tokens().size()) {
+WeightedToken token;
+if (i == tokens().size()) {
+token = embedding.tokens().get(j++);
+} else if (j == embedding.tokens().size()) {
+token = tokens().get(i++);
+} else if (tokens.get(i).weight() > embedding.tokens().get(j).weight()) {
+token = tokens().get(i++);
+} else {
+token = embedding.tokens().get(j++);
+}
+if (seenTokens.add(token.token())) {
+mergedTokens.add(token);
+}
+}
+boolean mergedIsTruncated = isTruncated || embedding.isTruncated();
+return new Embedding(mergedTokens, mergedIsTruncated);
}

-private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
+@Override
+public BytesReference toBytesRef(XContent xContent) throws IOException {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startObject();
for (var weightedToken : tokens) {
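
As a standalone walk-through of the sparse merge above (WeightedToken here is a local stand-in for the Elasticsearch class): both inputs are sorted by weight in descending order, the two-pointer sweep always takes the heavier head, and the seen-set keeps only the first, highest-weight occurrence of each token.

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class SparseMergeDemo {
    // Local stand-in for the Elasticsearch WeightedToken class; illustration only.
    record WeightedToken(String token, float weight) {}

    // Two-pointer merge of two weight-descending lists, keeping the
    // highest-weight entry per token (same shape as Embedding.merge above).
    static List<WeightedToken> merge(List<WeightedToken> a, List<WeightedToken> b) {
        List<WeightedToken> merged = new ArrayList<>();
        Set<String> seen = new HashSet<>();
        int i = 0, j = 0;
        while (i < a.size() || j < b.size()) {
            WeightedToken next;
            if (i == a.size()) {
                next = b.get(j++);
            } else if (j == b.size() || a.get(i).weight() > b.get(j).weight()) {
                next = a.get(i++);
            } else {
                next = b.get(j++);
            }
            if (seen.add(next.token())) {
                merged.add(next);
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        List<WeightedToken> a = List.of(new WeightedToken("fox", 0.9f), new WeightedToken("dog", 0.5f));
        List<WeightedToken> b = List.of(new WeightedToken("dog", 0.8f), new WeightedToken("cat", 0.3f));
        // "dog" keeps its higher weight (0.8); the result stays weight-descending:
        // [fox 0.9, dog 0.8, cat 0.3]
        System.out.println(merge(a, b));
    }
}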
TextEmbeddingBitResults.java
@@ -40,9 +40,11 @@
* ]
* }
*/
+// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the
+// Embedding.merge method for bits. TODO: implement a proper merge method
public record TextEmbeddingBitResults(List<TextEmbeddingByteResults.Embedding> embeddings)
implements
-TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
+TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
public static final String NAME = "text_embedding_service_bit_results";
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";

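
One plausible reading of that note (an assumption; this diff doesn't show the bit layout): bit embeddings pack eight bits per byte, so the inherited byte-wise averaging operates on packed bitfields, where the numeric average of two bytes has no per-bit meaning.

public class BitMergePitfall {
    public static void main(String[] args) {
        // Two packed bit vectors, one byte each (bit patterns 10101010 and 01010101).
        byte a = (byte) 0b10101010;
        byte b = (byte) 0b01010101;
        // Byte-wise rounded averaging, as the inherited merge would do:
        int n = 2;
        byte avg = (byte) ((a + b + n / 2) / n); // (-86 + 85 + 1) / 2 == 0
        // Every bit of avg ends up 0, although each bit position was set in exactly
        // one of the two inputs; arithmetic on packed bytes ignores bit boundaries.
        System.out.printf("a=%s b=%s avg=%s%n",
            Integer.toBinaryString(a & 0xFF),
            Integer.toBinaryString(b & 0xFF),
            Integer.toBinaryString(avg & 0xFF));
    }
}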
TextEmbeddingByteResults.java
@@ -15,7 +15,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
-import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
@@ -48,9 +47,7 @@
* ]
* }
*/
-public record TextEmbeddingByteResults(List<Embedding> embeddings)
-implements
-TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
+public record TextEmbeddingByteResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
public static final String NAME = "text_embedding_service_byte_results";
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";

@@ -118,9 +115,20 @@ public int hashCode() {
return Objects.hash(embeddings);
}

-public record Embedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk> {
+// Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
+// embeddings should happen in between serializations.
+public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings)
+implements
+Writeable,
+ToXContentObject,
+EmbeddingResults.Embedding<Embedding> {

public static final String EMBEDDING = "embedding";

+public Embedding(byte[] values) {
+this(values, null, 1);
+}
+
public Embedding(StreamInput in) throws IOException {
this(in.readByteArray());
}
@@ -187,25 +195,26 @@ public int hashCode() {
}

@Override
-public Chunk toChunk(ChunkedInference.TextOffset offset) {
-return new Chunk(values, offset);
-}
-}
-
-/**
-* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
-*/
-public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
+public Embedding merge(Embedding embedding) {
+byte[] newValues = new byte[values.length];
+int[] newSumMergedValues = new int[values.length];
+int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings;
+for (int i = 0; i < values.length; i++) {
+newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i]) + (embedding.numberOfMergedEmbeddings == 1 ? embedding.values[i] : embedding.sumMergedValues[i]);
+// Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the
+// closest byte instead of truncating.
+newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
+}
+return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings);
}

-private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
+@Override
+public BytesReference toBytesRef(XContent xContent) throws IOException {
XContentBuilder builder = XContentBuilder.builder(xContent);
builder.startArray();
-for (byte v : value) {
-builder.value(v);
+for (byte value : values) {
+builder.value(value);
}
builder.endArray();
return BytesReference.bytes(builder);
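
A worked example of the byte-merge arithmetic above, in plain Java: the running integer sum is what keeps repeated merges accurate, and the "+ n / 2" term rounds to the nearest byte rather than truncating. It also shows why the non-serialized sumMergedValues field matters, since re-averaging an already-rounded intermediate drifts away from the true mean.

public class ByteMergeDemo {
    public static void main(String[] args) {
        // Merge three one-dimensional byte embeddings: 10, 13, 0.
        int sum = 10 + 13;          // running sum after merging 10 and 13
        int n = 2;
        byte avg2 = (byte) ((sum + n / 2) / n);   // (23 + 1) / 2 = 12, rounded instead of truncated
        sum += 0;                   // merge in the third embedding
        n = 3;
        byte avg3 = (byte) ((sum + n / 2) / n);   // (23 + 1) / 3 = 8; true mean of 10, 13, 0 is 7.67
        // Without the running sum, re-averaging the rounded intermediate would drift:
        byte naive = (byte) ((avg2 + 0) / 2);     // 12 / 2 = 6, noticeably off from 8
        System.out.println(avg2 + " " + avg3 + " " + naive); // 12 8 6
    }
}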
TextEmbeddingFloatResults.java
@@ -16,7 +16,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
-import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
@@ -53,9 +52,7 @@
* ]
* }
*/
-public record TextEmbeddingFloatResults(List<Embedding> embeddings)
-implements
-TextEmbeddingResults<TextEmbeddingFloatResults.Chunk, TextEmbeddingFloatResults.Embedding> {
+public record TextEmbeddingFloatResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingFloatResults.Embedding> {
public static final String NAME = "text_embedding_service_results";
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();

@@ -155,9 +152,19 @@ public int hashCode() {
return Objects.hash(embeddings);
}

-public record Embedding(float[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk> {
+// Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
+// embeddings should happen in between serializations.
+public record Embedding(float[] values, int numberOfMergedEmbeddings)
+implements
+Writeable,
+ToXContentObject,
+EmbeddingResults.Embedding<Embedding> {
public static final String EMBEDDING = "embedding";

+public Embedding(float[] values) {
+this(values, 1);
+}
+
public Embedding(StreamInput in) throws IOException {
this(in.readFloatArray());
}
@@ -221,25 +228,21 @@ public int hashCode() {
}

@Override
-public Chunk toChunk(ChunkedInference.TextOffset offset) {
-return new Chunk(values, offset);
-}
-}
-
-public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
-
-public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
-return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
+public Embedding merge(Embedding embedding) {
+float[] mergedValues = new float[values.length];
+for (int i = 0; i < values.length; i++) {
+mergedValues[i] = (numberOfMergedEmbeddings * values[i] + embedding.numberOfMergedEmbeddings * embedding.values[i]) / (numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
+}
+return new Embedding(mergedValues, numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
}

-/**
-* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
-*/
-private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
+@Override
+public BytesReference toBytesRef(XContent xContent) throws IOException {
XContentBuilder b = XContentBuilder.builder(xContent);
b.startArray();
-for (float v : value) {
-b.value(v);
+for (float value : values) {
+b.value(value);
}
b.endArray();
return BytesReference.bytes(b);
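
The float merge is a weighted running mean: each operand contributes in proportion to the number of raw embeddings it already represents, so chained pairwise merges reproduce the plain arithmetic mean. A standalone check, using a local record rather than the Elasticsearch one:

public class FloatMergeDemo {
    record Emb(float[] values, int count) {
        // Weighted running mean, same shape as Embedding.merge above.
        Emb merge(Emb other) {
            float[] merged = new float[values.length];
            for (int i = 0; i < values.length; i++) {
                merged[i] = (count * values[i] + other.count * other.values[i]) / (count + other.count);
            }
            return new Emb(merged, count + other.count);
        }
    }

    public static void main(String[] args) {
        Emb a = new Emb(new float[] { 1.0f }, 1);
        Emb b = new Emb(new float[] { 3.0f }, 1);
        Emb c = new Emb(new float[] { 5.0f }, 1);
        // (1 + 3) / 2 = 2.0 with count 2; merging in 5 weights the pair twice:
        // (2 * 2.0 + 1 * 5.0) / 3 = 3.0, the plain mean of 1, 3, 5.
        Emb result = a.merge(b).merge(c);
        System.out.println(result.values()[0]); // 3.0
    }
}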
TextEmbeddingResults.java
@@ -7,9 +7,7 @@

package org.elasticsearch.xpack.core.inference.results;

-public interface TextEmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C>>
-extends
-EmbeddingResults<C, E> {
+public interface TextEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {

/**
* Returns the array size of the first text embedding entry in the result list.
ChatCompletionResultsTests.java
@@ -5,12 +5,11 @@
* 2.0.
*/

-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;

import java.io.IOException;
import java.util.ArrayList;
LegacyTextEmbeddingResultsTests.java
@@ -5,7 +5,7 @@
* 2.0.
*/

-package org.elasticsearch.xpack.inference.results;
+package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
@@ -14,7 +14,6 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;

import java.io.IOException;
import java.util.ArrayList;