Skip to content

Commit a503497

Browse files
authored
Add max.chunks to EmbeddingRequestChunker to prevent OOM (elastic#123150)
* add max number of chunks * wire merge function * implement sparse merge function * move tests to correct package/file * float merge function * bytes merge function * more accurate byte average * spotless * Fix/improve EmbeddingRequestChunkerTests * Remove TODO * remove unnecessary field * remove Chunk generic * add TODO * Remove specialized chunks * add comment * Update docs/changelog/123150.yaml * update changelog
1 parent c24f77f commit a503497

File tree

77 files changed

+756
-355
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+756
-355
lines changed

docs/changelog/123150.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 123150
2+
summary: Limit the number of chunks for semantic text to prevent high memory usage
3+
area: Machine Learning
4+
type: feature
5+
issues: []

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbedding.java

+6-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings;
1919

20-
public record ChunkedInferenceEmbedding(List<? extends EmbeddingResults.Chunk> chunks) implements ChunkedInference {
20+
public record ChunkedInferenceEmbedding(List<EmbeddingResults.Chunk> chunks) implements ChunkedInference {
2121

2222
public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbeddingResults sparseEmbeddingResults) {
2323
validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size());
@@ -27,10 +27,7 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbedding
2727
results.add(
2828
new ChunkedInferenceEmbedding(
2929
List.of(
30-
new SparseEmbeddingResults.Chunk(
31-
sparseEmbeddingResults.embeddings().get(i).tokens(),
32-
new TextOffset(0, inputs.get(i).length())
33-
)
30+
new EmbeddingResults.Chunk(sparseEmbeddingResults.embeddings().get(i), new TextOffset(0, inputs.get(i).length()))
3431
)
3532
)
3633
);
@@ -41,10 +38,10 @@ public static List<ChunkedInference> listOf(List<String> inputs, SparseEmbedding
4138

4239
@Override
4340
public Iterator<Chunk> chunksAsByteReference(XContent xcontent) throws IOException {
44-
var asChunk = new ArrayList<Chunk>();
45-
for (var chunk : chunks()) {
46-
asChunk.add(chunk.toChunk(xcontent));
41+
List<Chunk> chunkedInferenceChunks = new ArrayList<>();
42+
for (EmbeddingResults.Chunk embeddingResultsChunk : chunks()) {
43+
chunkedInferenceChunks.add(new Chunk(embeddingResultsChunk.offset(), embeddingResultsChunk.embedding().toBytesRef(xcontent)));
4744
}
48-
return asChunk.iterator();
45+
return chunkedInferenceChunks.iterator();
4946
}
5047
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java

+15-15
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.results;
99

10+
import org.elasticsearch.common.bytes.BytesReference;
1011
import org.elasticsearch.inference.ChunkedInference;
1112
import org.elasticsearch.inference.InferenceServiceResults;
1213
import org.elasticsearch.xcontent.XContent;
@@ -19,31 +20,30 @@
1920
* A call to the inference service may contain multiple input texts, so these results may
2021
* contain multiple results.
2122
*/
22-
public interface EmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C>>
23-
extends
24-
InferenceServiceResults {
25-
26-
/**
27-
* A resulting embedding together with the offset into the input text.
28-
*/
29-
interface Chunk {
30-
ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException;
31-
32-
ChunkedInference.TextOffset offset();
33-
}
23+
public interface EmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends InferenceServiceResults {
3424

3525
/**
3626
* A resulting embedding for one of the input texts to the inference service.
3727
*/
38-
interface Embedding<C extends Chunk> {
28+
interface Embedding<E extends Embedding<E>> {
3929
/**
40-
* Combines the resulting embedding with the offset into the input text into a chunk.
30+
* Merges the existing embedding and provided embedding into a new embedding.
4131
*/
42-
C toChunk(ChunkedInference.TextOffset offset);
32+
E merge(E embedding);
33+
34+
/**
35+
* Serializes the embedding to bytes.
36+
*/
37+
BytesReference toBytesRef(XContent xContent) throws IOException;
4338
}
4439

4540
/**
4641
* The resulting list of embeddings for the input texts to the inference service.
4742
*/
4843
List<E> embeddings();
44+
45+
/**
46+
* A resulting embedding together with the offset into the input text.
47+
*/
48+
record Chunk(Embedding<?> embedding, ChunkedInference.TextOffset offset) {}
4949
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java

+31-15
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.common.io.stream.Writeable;
1616
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
17-
import org.elasticsearch.inference.ChunkedInference;
1817
import org.elasticsearch.inference.InferenceResults;
1918
import org.elasticsearch.inference.TaskType;
2019
import org.elasticsearch.rest.RestStatus;
@@ -27,17 +26,17 @@
2726

2827
import java.io.IOException;
2928
import java.util.ArrayList;
29+
import java.util.HashSet;
3030
import java.util.Iterator;
3131
import java.util.LinkedHashMap;
3232
import java.util.List;
3333
import java.util.Map;
34+
import java.util.Set;
3435
import java.util.stream.Collectors;
3536

3637
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
3738

38-
public record SparseEmbeddingResults(List<Embedding> embeddings)
39-
implements
40-
EmbeddingResults<SparseEmbeddingResults.Chunk, SparseEmbeddingResults.Embedding> {
39+
public record SparseEmbeddingResults(List<Embedding> embeddings) implements EmbeddingResults<SparseEmbeddingResults.Embedding> {
4140

4241
public static final String NAME = "sparse_embedding_results";
4342
public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString();
@@ -124,7 +123,7 @@ public record Embedding(List<WeightedToken> tokens, boolean isTruncated)
124123
implements
125124
Writeable,
126125
ToXContentObject,
127-
EmbeddingResults.Embedding<Chunk> {
126+
EmbeddingResults.Embedding<Embedding> {
128127

129128
public static final String EMBEDDING = "embedding";
130129
public static final String IS_TRUNCATED = "is_truncated";
@@ -175,18 +174,35 @@ public String toString() {
175174
}
176175

177176
@Override
178-
public Chunk toChunk(ChunkedInference.TextOffset offset) {
179-
return new Chunk(tokens, offset);
180-
}
181-
}
182-
183-
public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
184-
185-
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
186-
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens));
177+
public Embedding merge(Embedding embedding) {
178+
// This code assumes that the tokens are sorted by weight in descending order.
179+
// If that's not the case, the resulting merged embedding will be incorrect.
180+
List<WeightedToken> mergedTokens = new ArrayList<>();
181+
Set<String> seenTokens = new HashSet<>();
182+
int i = 0;
183+
int j = 0;
184+
// TODO: maybe truncate tokens here when it's getting too large?
185+
while (i < tokens().size() || j < embedding.tokens().size()) {
186+
WeightedToken token;
187+
if (i == tokens().size()) {
188+
token = embedding.tokens().get(j++);
189+
} else if (j == embedding.tokens().size()) {
190+
token = tokens().get(i++);
191+
} else if (tokens.get(i).weight() > embedding.tokens().get(j).weight()) {
192+
token = tokens().get(i++);
193+
} else {
194+
token = embedding.tokens().get(j++);
195+
}
196+
if (seenTokens.add(token.token())) {
197+
mergedTokens.add(token);
198+
}
199+
}
200+
boolean mergedIsTruncated = isTruncated || embedding.isTruncated();
201+
return new Embedding(mergedTokens, mergedIsTruncated);
187202
}
188203

189-
private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException {
204+
@Override
205+
public BytesReference toBytesRef(XContent xContent) throws IOException {
190206
XContentBuilder b = XContentBuilder.builder(xContent);
191207
b.startObject();
192208
for (var weightedToken : tokens) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@
4040
* ]
4141
* }
4242
*/
43+
// Note: inheriting from TextEmbeddingByteResults gives a bad implementation of the
44+
// Embedding.merge method for bits. TODO: implement a proper merge method
4345
public record TextEmbeddingBitResults(List<TextEmbeddingByteResults.Embedding> embeddings)
4446
implements
45-
TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
47+
TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
4648
public static final String NAME = "text_embedding_service_bit_results";
4749
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
4850

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java

+29-20
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import org.elasticsearch.common.io.stream.StreamOutput;
1616
import org.elasticsearch.common.io.stream.Writeable;
1717
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
18-
import org.elasticsearch.inference.ChunkedInference;
1918
import org.elasticsearch.inference.InferenceResults;
2019
import org.elasticsearch.xcontent.ToXContent;
2120
import org.elasticsearch.xcontent.ToXContentObject;
@@ -48,9 +47,7 @@
4847
* ]
4948
* }
5049
*/
51-
public record TextEmbeddingByteResults(List<Embedding> embeddings)
52-
implements
53-
TextEmbeddingResults<TextEmbeddingByteResults.Chunk, TextEmbeddingByteResults.Embedding> {
50+
public record TextEmbeddingByteResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingByteResults.Embedding> {
5451
public static final String NAME = "text_embedding_service_byte_results";
5552
public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes";
5653

@@ -118,9 +115,20 @@ public int hashCode() {
118115
return Objects.hash(embeddings);
119116
}
120117

121-
public record Embedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk> {
118+
// Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
119+
// embeddings should happen in between serializations.
120+
public record Embedding(byte[] values, int[] sumMergedValues, int numberOfMergedEmbeddings)
121+
implements
122+
Writeable,
123+
ToXContentObject,
124+
EmbeddingResults.Embedding<Embedding> {
125+
122126
public static final String EMBEDDING = "embedding";
123127

128+
public Embedding(byte[] values) {
129+
this(values, null, 1);
130+
}
131+
124132
public Embedding(StreamInput in) throws IOException {
125133
this(in.readByteArray());
126134
}
@@ -187,25 +195,26 @@ public int hashCode() {
187195
}
188196

189197
@Override
190-
public Chunk toChunk(ChunkedInference.TextOffset offset) {
191-
return new Chunk(values, offset);
192-
}
193-
}
194-
195-
/**
196-
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
197-
*/
198-
public record Chunk(byte[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
199-
200-
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
201-
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
198+
public Embedding merge(Embedding embedding) {
199+
byte[] newValues = new byte[values.length];
200+
int[] newSumMergedValues = new int[values.length];
201+
int newNumberOfMergedEmbeddings = numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings;
202+
for (int i = 0; i < values.length; i++) {
203+
newSumMergedValues[i] = (numberOfMergedEmbeddings == 1 ? values[i] : sumMergedValues[i])
204+
+ (embedding.numberOfMergedEmbeddings == 1 ? embedding.values[i] : embedding.sumMergedValues[i]);
205+
// Add (newNumberOfMergedEmbeddings / 2) in the numerator to round towards the
206+
// closest byte instead of truncating.
207+
newValues[i] = (byte) ((newSumMergedValues[i] + newNumberOfMergedEmbeddings / 2) / newNumberOfMergedEmbeddings);
208+
}
209+
return new Embedding(newValues, newSumMergedValues, newNumberOfMergedEmbeddings);
202210
}
203211

204-
private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException {
212+
@Override
213+
public BytesReference toBytesRef(XContent xContent) throws IOException {
205214
XContentBuilder builder = XContentBuilder.builder(xContent);
206215
builder.startArray();
207-
for (byte v : value) {
208-
builder.value(v);
216+
for (byte value : values) {
217+
builder.value(value);
209218
}
210219
builder.endArray();
211220
return BytesReference.bytes(builder);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingFloatResults.java

+23-20
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.elasticsearch.common.io.stream.StreamOutput;
1717
import org.elasticsearch.common.io.stream.Writeable;
1818
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
19-
import org.elasticsearch.inference.ChunkedInference;
2019
import org.elasticsearch.inference.InferenceResults;
2120
import org.elasticsearch.inference.TaskType;
2221
import org.elasticsearch.rest.RestStatus;
@@ -53,9 +52,7 @@
5352
* ]
5453
* }
5554
*/
56-
public record TextEmbeddingFloatResults(List<Embedding> embeddings)
57-
implements
58-
TextEmbeddingResults<TextEmbeddingFloatResults.Chunk, TextEmbeddingFloatResults.Embedding> {
55+
public record TextEmbeddingFloatResults(List<Embedding> embeddings) implements TextEmbeddingResults<TextEmbeddingFloatResults.Embedding> {
5956
public static final String NAME = "text_embedding_service_results";
6057
public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString();
6158

@@ -155,9 +152,19 @@ public int hashCode() {
155152
return Objects.hash(embeddings);
156153
}
157154

158-
public record Embedding(float[] values) implements Writeable, ToXContentObject, EmbeddingResults.Embedding<Chunk> {
155+
// Note: the field "numberOfMergedEmbeddings" is not serialized, so merging
156+
// embeddings should happen in between serializations.
157+
public record Embedding(float[] values, int numberOfMergedEmbeddings)
158+
implements
159+
Writeable,
160+
ToXContentObject,
161+
EmbeddingResults.Embedding<Embedding> {
159162
public static final String EMBEDDING = "embedding";
160163

164+
public Embedding(float[] values) {
165+
this(values, 1);
166+
}
167+
161168
public Embedding(StreamInput in) throws IOException {
162169
this(in.readFloatArray());
163170
}
@@ -221,25 +228,21 @@ public int hashCode() {
221228
}
222229

223230
@Override
224-
public Chunk toChunk(ChunkedInference.TextOffset offset) {
225-
return new Chunk(values, offset);
226-
}
227-
}
228-
229-
public record Chunk(float[] embedding, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk {
230-
231-
public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException {
232-
return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, embedding));
231+
public Embedding merge(Embedding embedding) {
232+
float[] mergedValues = new float[values.length];
233+
for (int i = 0; i < values.length; i++) {
234+
mergedValues[i] = (numberOfMergedEmbeddings * values[i] + embedding.numberOfMergedEmbeddings * embedding.values[i])
235+
/ (numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
236+
}
237+
return new Embedding(mergedValues, numberOfMergedEmbeddings + embedding.numberOfMergedEmbeddings);
233238
}
234239

235-
/**
236-
* Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
237-
*/
238-
private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException {
240+
@Override
241+
public BytesReference toBytesRef(XContent xContent) throws IOException {
239242
XContentBuilder b = XContentBuilder.builder(xContent);
240243
b.startArray();
241-
for (float v : value) {
242-
b.value(v);
244+
for (float value : values) {
245+
b.value(value);
243246
}
244247
b.endArray();
245248
return BytesReference.bytes(b);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
package org.elasticsearch.xpack.core.inference.results;
99

10-
public interface TextEmbeddingResults<C extends EmbeddingResults.Chunk, E extends EmbeddingResults.Embedding<C>>
11-
extends
12-
EmbeddingResults<C, E> {
10+
public interface TextEmbeddingResults<E extends EmbeddingResults.Embedding<E>> extends EmbeddingResults<E> {
1311

1412
/**
1513
* Returns the first text embedding entry in the result list's array size.
+1-2
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.results;
8+
package org.elasticsearch.xpack.core.inference.results;
99

1010
import org.elasticsearch.common.Strings;
1111
import org.elasticsearch.common.io.stream.Writeable;
1212
import org.elasticsearch.test.AbstractWireSerializingTestCase;
13-
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
1413

1514
import java.io.IOException;
1615
import java.util.ArrayList;
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.results;
8+
package org.elasticsearch.xpack.core.inference.results;
99

1010
import org.elasticsearch.common.Strings;
1111
import org.elasticsearch.common.io.stream.Writeable;
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.xcontent.XContentBuilder;
1515
import org.elasticsearch.xcontent.XContentFactory;
1616
import org.elasticsearch.xcontent.XContentType;
17-
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
1817

1918
import java.io.IOException;
2019
import java.util.ArrayList;

0 commit comments

Comments
 (0)