Skip to content

[8.x] Adding support for binary embedding type to Cohere service embedding type (#120751) #121584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/120751.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 120751
summary: Adding support for binary embedding type to Cohere service embedding type
area: Machine Learning
type: enhancement
issues: []
6 changes: 4 additions & 2 deletions docs/reference/inference/service-cohere.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ include::inference-shared.asciidoc[tag=chunking-settings-strategy]

`service`::
(Required, string)
The type of service supported for the specified task type. In this case,
The type of service supported for the specified task type. In this case,
`cohere`.

`service_settings`::
Expand Down Expand Up @@ -127,6 +127,8 @@ Valid values are:
* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`).
* `float`: use it for the default float embeddings.
* `int8`: use it for signed int8 embeddings.
* `binary`: use it for binary embeddings, which are encoded as bytes with signed int8 precision.
* `bit`: use it for binary embeddings, which are encoded as bytes with signed int8 precision (this is a synonym of `binary`).

`model_id`:::
(Optional, string)
Expand Down Expand Up @@ -228,4 +230,4 @@ PUT _inference/rerank/cohere-rerank
// TEST[skip:TBD]

For more examples, also review the
https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation].
https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation].
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ static TransportVersion def(int id) {
public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_0_00);
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00);
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X = def(8_840_0_01);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

/**
 * A single dense embedding whose components are signed 8-bit integers.
 * Used for Cohere {@code int8}/{@code byte} embeddings and for {@code bit}/{@code binary}
 * embeddings (which are delivered packed into signed bytes).
 */
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
    public static final String EMBEDDING = "embedding";

    public InferenceByteEmbedding(StreamInput in) throws IOException {
        this(in.readByteArray());
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeByteArray(values);
    }

    /**
     * Builds an embedding from a boxed list, unboxing into a primitive array.
     *
     * @param embeddingValuesList the embedding components; must not contain nulls
     * @return a new {@link InferenceByteEmbedding} backed by a freshly allocated array
     */
    public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
        byte[] embeddingValues = new byte[embeddingValuesList.size()];
        for (int i = 0; i < embeddingValuesList.size(); i++) {
            embeddingValues[i] = embeddingValuesList.get(i);
        }
        return new InferenceByteEmbedding(embeddingValues);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        builder.startArray(EMBEDDING);
        for (byte value : values) {
            builder.value(value);
        }
        builder.endArray();

        builder.endObject();
        return builder;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    /** Widens each byte component to {@code float}. Package-private for sibling results classes. */
    float[] toFloatArray() {
        float[] floatArray = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            // implicit primitive widening; no need to box through Byte
            floatArray[i] = values[i];
        }
        return floatArray;
    }

    /** Widens each byte component to {@code double}. Package-private for sibling results classes. */
    double[] toDoubleArray() {
        double[] doubleArray = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            // implicit primitive widening; no need to box through Byte
            doubleArray[i] = values[i];
        }
        return doubleArray;
    }

    @Override
    public int getSize() {
        return values().length;
    }

    // Records derive equals/hashCode from field references; arrays need content-based
    // comparison, so both are overridden to use Arrays utilities.
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
        return Arrays.equals(values, embedding.values);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(values);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a generative AI
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Writes a bit-embedding (binary) text embedding result in the following json format
 * {
 *     "text_embedding_bits": [
 *         {
 *             "embedding": [
 *                 23
 *             ]
 *         },
 *         {
 *             "embedding": [
 *                 -23
 *             ]
 *         }
 *     ]
 * }
 *
 * Each embedding holds bits packed into signed bytes (see {@link InferenceByteEmbedding}),
 * which is how Cohere returns {@code binary}/{@code bit} embedding types.
 */
public record InferenceTextEmbeddingBitResults(List<InferenceByteEmbedding> embeddings) implements InferenceServiceResults, TextEmbedding {
    public static final String NAME = "text_embedding_service_bit_results";
    public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";

    public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException {
        this(in.readCollectionAsList(InferenceByteEmbedding::new));
    }

    @Override
    public int getFirstEmbeddingSize() {
        // TextEmbeddingUtils expects a List<EmbeddingInt>; copy to satisfy the signature
        return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
    }

    @Override
    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
        return ChunkedToXContent.builder(params).array(TEXT_EMBEDDING_BITS, embeddings.iterator());
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(embeddings);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public List<? extends InferenceResults> transformToCoordinationFormat() {
        // Bit/byte values are widened to doubles to fit the ML text-embedding result shape
        return embeddings.stream()
            .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
            .toList();
    }

    @Override
    @SuppressWarnings("deprecation")
    public List<? extends InferenceResults> transformToLegacyFormat() {
        var legacyEmbedding = new LegacyTextEmbeddingResults(
            embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloatArray())).toList()
        );

        return List.of(legacyEmbedding);
    }

    public Map<String, Object> asMap() {
        Map<String, Object> map = new LinkedHashMap<>();
        map.put(TEXT_EMBEDDING_BITS, embeddings);

        return map;
    }

    // Explicit equals/hashCode kept for symmetry with sibling results classes; records
    // would derive equivalent implementations since List.equals is content-based.
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        InferenceTextEmbeddingBitResults that = (InferenceTextEmbeddingBitResults) o;
        return Objects.equals(embeddings, that.embeddings);
    }

    @Override
    public int hashCode() {
        return Objects.hash(embeddings);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,16 @@

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
Expand All @@ -33,7 +28,7 @@
/**
* Writes a text embedding result in the follow json format
* {
* "text_embedding": [
* "text_embedding_bytes": [
* {
* "embedding": [
* 23
Expand Down Expand Up @@ -111,78 +106,4 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(embeddings);
}

public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
public static final String EMBEDDING = "embedding";

public InferenceByteEmbedding(StreamInput in) throws IOException {
this(in.readByteArray());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeByteArray(values);
}

public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
byte[] embeddingValues = new byte[embeddingValuesList.size()];
for (int i = 0; i < embeddingValuesList.size(); i++) {
embeddingValues[i] = embeddingValuesList.get(i);
}
return new InferenceByteEmbedding(embeddingValues);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

builder.startArray(EMBEDDING);
for (byte value : values) {
builder.value(value);
}
builder.endArray();

builder.endObject();
return builder;
}

@Override
public String toString() {
return Strings.toString(this);
}

private float[] toFloatArray() {
float[] floatArray = new float[values.length];
for (int i = 0; i < values.length; i++) {
floatArray[i] = ((Byte) values[i]).floatValue();
}
return floatArray;
}

private double[] toDoubleArray() {
double[] doubleArray = new double[values.length];
for (int i = 0; i < values.length; i++) {
doubleArray[i] = ((Byte) values[i]).floatValue();
}
return doubleArray;
}

@Override
public int getSize() {
return values().length;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
return Arrays.equals(values, embedding.values);
}

@Override
public int hashCode() {
return Arrays.hashCode(values);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
Expand Down Expand Up @@ -69,7 +70,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El

private List<ChunkOffsetsAndInput> chunkedOffsets;
private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
private List<AtomicArray<List<InferenceByteEmbedding>>> byteResults;
private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
private AtomicArray<Exception> errors;
private ActionListener<List<ChunkedInference>> finalListener;
Expand Down Expand Up @@ -389,9 +390,9 @@ private ChunkedInferenceEmbeddingFloat mergeFloatResultsWithInputs(

private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
ChunkOffsetsAndInput chunks,
AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>> debatchedResults
AtomicArray<List<InferenceByteEmbedding>> debatchedResults
) {
var all = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
var all = new ArrayList<InferenceByteEmbedding>();
for (int i = 0; i < debatchedResults.length(); i++) {
var subBatch = debatchedResults.get(i);
all.addAll(subBatch);
Expand Down
Loading