|
14 | 14 | import org.elasticsearch.common.io.stream.StreamOutput;
|
15 | 15 | import org.elasticsearch.common.io.stream.Writeable;
|
16 | 16 | import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
|
17 |
| -import org.elasticsearch.inference.ChunkedInference; |
18 | 17 | import org.elasticsearch.inference.InferenceResults;
|
19 | 18 | import org.elasticsearch.inference.TaskType;
|
20 | 19 | import org.elasticsearch.rest.RestStatus;
|
|
27 | 26 |
|
28 | 27 | import java.io.IOException;
|
29 | 28 | import java.util.ArrayList;
|
| 29 | +import java.util.HashSet; |
30 | 30 | import java.util.Iterator;
|
31 | 31 | import java.util.LinkedHashMap;
|
32 | 32 | import java.util.List;
|
33 | 33 | import java.util.Map;
|
| 34 | +import java.util.Set; |
34 | 35 | import java.util.stream.Collectors;
|
35 | 36 |
|
36 | 37 | import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
|
37 | 38 |
|
38 |
| -public record SparseEmbeddingResults(List<Embedding> embeddings) |
39 |
| - implements |
40 |
| - EmbeddingResults<SparseEmbeddingResults.Chunk, SparseEmbeddingResults.Embedding> { |
| 39 | +public record SparseEmbeddingResults(List<Embedding> embeddings) implements EmbeddingResults<SparseEmbeddingResults.Embedding> { |
41 | 40 |
|
42 | 41 | public static final String NAME = "sparse_embedding_results";
|
43 | 42 | public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString();
|
@@ -124,7 +123,7 @@ public record Embedding(List<WeightedToken> tokens, boolean isTruncated)
|
124 | 123 | implements
|
125 | 124 | Writeable,
|
126 | 125 | ToXContentObject,
|
127 |
| - EmbeddingResults.Embedding<Chunk> { |
| 126 | + EmbeddingResults.Embedding<Embedding> { |
128 | 127 |
|
129 | 128 | public static final String EMBEDDING = "embedding";
|
130 | 129 | public static final String IS_TRUNCATED = "is_truncated";
|
@@ -175,18 +174,35 @@ public String toString() {
|
175 | 174 | }
|
176 | 175 |
|
177 | 176 | @Override
|
178 |
| - public Chunk toChunk(ChunkedInference.TextOffset offset) { |
179 |
| - return new Chunk(tokens, offset); |
180 |
| - } |
181 |
| - } |
182 |
| - |
183 |
| - public record Chunk(List<WeightedToken> weightedTokens, ChunkedInference.TextOffset offset) implements EmbeddingResults.Chunk { |
184 |
| - |
185 |
| - public ChunkedInference.Chunk toChunk(XContent xcontent) throws IOException { |
186 |
| - return new ChunkedInference.Chunk(offset, toBytesReference(xcontent, weightedTokens)); |
| 177 | + public Embedding merge(Embedding embedding) { |
| 178 | + // This code assumes that the tokens are sorted by weight in descending order. |
| 179 | + // If that's not the case, the resulting merged embedding will be incorrect. |
| 180 | + List<WeightedToken> mergedTokens = new ArrayList<>(); |
| 181 | + Set<String> seenTokens = new HashSet<>(); |
| 182 | + int i = 0; |
| 183 | + int j = 0; |
| 184 | + // TODO: maybe truncate tokens here when it's getting too large? |
| 185 | + while (i < tokens().size() || j < embedding.tokens().size()) { |
| 186 | + WeightedToken token; |
| 187 | + if (i == tokens().size()) { |
| 188 | + token = embedding.tokens().get(j++); |
| 189 | + } else if (j == embedding.tokens().size()) { |
| 190 | + token = tokens().get(i++); |
| 191 | + } else if (tokens.get(i).weight() > embedding.tokens().get(j).weight()) { |
| 192 | + token = tokens().get(i++); |
| 193 | + } else { |
| 194 | + token = embedding.tokens().get(j++); |
| 195 | + } |
| 196 | + if (seenTokens.add(token.token())) { |
| 197 | + mergedTokens.add(token); |
| 198 | + } |
| 199 | + } |
| 200 | + boolean mergedIsTruncated = isTruncated || embedding.isTruncated(); |
| 201 | + return new Embedding(mergedTokens, mergedIsTruncated); |
187 | 202 | }
|
188 | 203 |
|
189 |
| - private static BytesReference toBytesReference(XContent xContent, List<WeightedToken> tokens) throws IOException { |
| 204 | + @Override |
| 205 | + public BytesReference toBytesRef(XContent xContent) throws IOException { |
190 | 206 | XContentBuilder b = XContentBuilder.builder(xContent);
|
191 | 207 | b.startObject();
|
192 | 208 | for (var weightedToken : tokens) {
|
|
0 commit comments