Skip to content

Commit aa60000

Browse files
feat: [ML] Support binary embeddings from Amazon Bedrock Titan (elastic#125378)
1 parent 7c77ead commit aa60000

File tree

8 files changed

+388
-120
lines changed

8 files changed

+388
-120
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ public class AmazonBedrockConstants {
1111
public static final String ACCESS_KEY_FIELD = "access_key";
1212
public static final String SECRET_KEY_FIELD = "secret_key";
1313
public static final String REGION_FIELD = "region";
14-
public static final String MODEL_FIELD = "model";
14+
public static final String MODEL_FIELD = "model_id";
1515
public static final String PROVIDER_FIELD = "provider";
16+
public static final String EMBEDDING_TYPE_FIELD = "embedding_type";
1617

1718
public static final String TEMPERATURE_FIELD = "temperature";
1819
public static final String TOP_P_FIELD = "top_p";
@@ -24,4 +25,5 @@ public class AmazonBedrockConstants {
2425

2526
public static final int DEFAULT_MAX_CHUNK_SIZE = 2048;
2627

28+
private AmazonBedrockConstants() {}
2729
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java

+53-3
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,45 @@
2424
import java.util.EnumSet;
2525
import java.util.Map;
2626
import java.util.Objects;
27+
import java.util.Optional;
2728

2829
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredEnum;
2930
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
31+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
3032
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD;
3133
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD;
3234
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD;
35+
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD;
3336

3437
public abstract class AmazonBedrockServiceSettings extends FilteredXContentObject implements ServiceSettings {
3538

3639
protected static final String AMAZON_BEDROCK_BASE_NAME = "amazon_bedrock";
3740

41+
public enum AmazonBedrockEmbeddingType {
42+
FLOAT,
43+
BINARY;
44+
45+
public static AmazonBedrockEmbeddingType fromString(String value) {
46+
return switch (value.toLowerCase()) {
47+
case "float" -> FLOAT;
48+
case "binary" -> BINARY;
49+
default -> throw new IllegalArgumentException("unknown value for embedding type: " + value);
50+
};
51+
}
52+
53+
@Override
54+
public String toString() {
55+
return name().toLowerCase();
56+
}
57+
}
58+
59+
protected static final AmazonBedrockEmbeddingType DEFAULT_EMBEDDING_TYPE = AmazonBedrockEmbeddingType.FLOAT;
60+
3861
protected final String region;
3962
protected final String model;
4063
protected final AmazonBedrockProvider provider;
4164
protected final RateLimitSettings rateLimitSettings;
65+
protected final AmazonBedrockEmbeddingType embeddingType;
4266

4367
// the default requests per minute are defined as per-model in the "Runtime quotas" on AWS
4468
// see: https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html
@@ -69,34 +93,50 @@ protected static AmazonBedrockServiceSettings.BaseAmazonBedrockCommonSettings fr
6993
AMAZON_BEDROCK_BASE_NAME,
7094
context
7195
);
96+
AmazonBedrockEmbeddingType embeddingType = extractOptionalEnum(
97+
map,
98+
EMBEDDING_TYPE_FIELD,
99+
ModelConfigurations.SERVICE_SETTINGS,
100+
AmazonBedrockEmbeddingType::fromString,
101+
EnumSet.allOf(AmazonBedrockEmbeddingType.class),
102+
validationException
103+
).orElse(DEFAULT_EMBEDDING_TYPE);
72104

73-
return new BaseAmazonBedrockCommonSettings(region, model, provider, rateLimitSettings);
105+
return new BaseAmazonBedrockCommonSettings(region, model, provider, rateLimitSettings, embeddingType);
74106
}
75107

76108
protected record BaseAmazonBedrockCommonSettings(
77109
String region,
78110
String model,
79111
AmazonBedrockProvider provider,
80-
@Nullable RateLimitSettings rateLimitSettings
112+
@Nullable RateLimitSettings rateLimitSettings,
113+
AmazonBedrockEmbeddingType embeddingType
81114
) {}
82115

83116
protected AmazonBedrockServiceSettings(StreamInput in) throws IOException {
84117
this.region = in.readString();
85118
this.model = in.readString();
86119
this.provider = in.readEnum(AmazonBedrockProvider.class);
87120
this.rateLimitSettings = new RateLimitSettings(in);
121+
if (in.getTransportVersion().onOrAfter(TransportVersions.V_9_0_0)) { // Version set for BWC
122+
this.embeddingType = in.readEnum(AmazonBedrockEmbeddingType.class);
123+
} else {
124+
this.embeddingType = DEFAULT_EMBEDDING_TYPE;
125+
}
88126
}
89127

90128
protected AmazonBedrockServiceSettings(
91129
String region,
92130
String model,
93131
AmazonBedrockProvider provider,
94-
@Nullable RateLimitSettings rateLimitSettings
132+
@Nullable RateLimitSettings rateLimitSettings,
133+
AmazonBedrockEmbeddingType embeddingType
95134
) {
96135
this.region = Objects.requireNonNull(region);
97136
this.model = Objects.requireNonNull(model);
98137
this.provider = Objects.requireNonNull(provider);
99138
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
139+
this.embeddingType = Objects.requireNonNullElse(embeddingType, DEFAULT_EMBEDDING_TYPE);
100140
}
101141

102142
@Override
@@ -121,12 +161,19 @@ public RateLimitSettings rateLimitSettings() {
121161
return rateLimitSettings;
122162
}
123163

164+
public AmazonBedrockEmbeddingType embeddingType() {
165+
return embeddingType;
166+
}
167+
124168
@Override
125169
public void writeTo(StreamOutput out) throws IOException {
126170
out.writeString(region);
127171
out.writeString(model);
128172
out.writeEnum(provider);
129173
rateLimitSettings.writeTo(out);
174+
if (out.getTransportVersion().onOrAfter(TransportVersions.V_9_0_0)) { // Version set for BWC
175+
out.writeEnum(embeddingType);
176+
}
130177
}
131178

132179
public void addBaseXContent(XContentBuilder builder, Params params) throws IOException {
@@ -137,6 +184,9 @@ protected void addXContentFragmentOfExposedFields(XContentBuilder builder, Param
137184
builder.field(REGION_FIELD, region);
138185
builder.field(MODEL_FIELD, model);
139186
builder.field(PROVIDER_FIELD, provider.name());
187+
if (embeddingType != DEFAULT_EMBEDDING_TYPE) {
188+
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toString());
189+
}
140190
rateLimitSettings.toXContent(builder, params);
141191
}
142192
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public static ToXContent createEntity(
3636
if (truncatedInput.size() > 1) {
3737
throw new ElasticsearchException("[input] cannot contain more than one string");
3838
}
39-
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
39+
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0), serviceSettings.embeddingType());
4040
}
4141
case COHERE -> {
4242
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,26 @@
1212

1313
import java.io.IOException;
1414
import java.util.Objects;
15+
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType;
16+
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD;
1517

16-
public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText) implements ToXContentObject {
18+
public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText, AmazonBedrockEmbeddingType embeddingType)
19+
implements ToXContentObject {
1720

1821
private static final String INPUT_TEXT_FIELD = "inputText";
1922

2023
public AmazonBedrockTitanEmbeddingsRequestEntity {
2124
Objects.requireNonNull(inputText);
25+
Objects.requireNonNull(embeddingType);
2226
}
2327

2428
@Override
2529
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
2630
builder.startObject();
2731
builder.field(INPUT_TEXT_FIELD, inputText);
32+
if (embeddingType == AmazonBedrockEmbeddingType.BINARY) {
33+
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toString());
34+
}
2835
builder.endObject();
2936
return builder;
3037
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java

+49-15
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@
1616
import org.elasticsearch.xcontent.XContentParser;
1717
import org.elasticsearch.xcontent.XContentParserConfiguration;
1818
import org.elasticsearch.xcontent.XContentType;
19-
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
19+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
20+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBytesResults;
2021
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
2122
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
2223
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest;
2324
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest;
2425
import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponse;
26+
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType;
2527

2628
import java.io.IOException;
2729
import java.nio.charset.StandardCharsets;
2830
import java.util.List;
31+
import java.util.Base64;
2932

3033
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3134
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
@@ -42,13 +45,13 @@ public AmazonBedrockEmbeddingsResponse(InvokeModelResponse invokeModelResult) {
4245
@Override
4346
public InferenceServiceResults accept(AmazonBedrockRequest request) {
4447
if (request instanceof AmazonBedrockEmbeddingsRequest asEmbeddingsRequest) {
45-
return fromResponse(result, asEmbeddingsRequest.provider());
48+
return fromResponse(result, asEmbeddingsRequest);
4649
}
4750

4851
throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]");
4952
}
5053

51-
public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) {
54+
public static TextEmbeddingResults fromResponse(InvokeModelResponse response, AmazonBedrockEmbeddingsRequest request) {
5255
var charset = StandardCharsets.UTF_8;
5356
var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer()));
5457

@@ -61,28 +64,33 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons
6164
XContentParser.Token token = jsonParser.currentToken();
6265
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
6366

64-
var embeddingList = parseEmbeddings(jsonParser, provider);
65-
66-
return new TextEmbeddingFloatResults(embeddingList);
67+
var embeddingType = request.getServiceSettings().embeddingType();
68+
if (embeddingType == AmazonBedrockEmbeddingType.BINARY) {
69+
var embeddingList = parseBinaryEmbeddings(jsonParser, request.provider());
70+
return new TextEmbeddingBytesResults(embeddingList);
71+
} else {
72+
var embeddingList = parseFloatEmbeddings(jsonParser, request.provider());
73+
return new TextEmbeddingFloatResults(embeddingList);
74+
}
6775
} catch (IOException e) {
6876
throw new ElasticsearchException(e);
6977
}
7078
}
7179

72-
private static List<TextEmbeddingFloatResults.Embedding> parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
80+
private static List<TextEmbeddingResults.InferredValue> parseFloatEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
7381
throws IOException {
7482
switch (provider) {
7583
case AMAZONTITAN -> {
76-
return parseTitanEmbeddings(jsonParser);
84+
return parseTitanFloatEmbeddings(jsonParser);
7785
}
7886
case COHERE -> {
79-
return parseCohereEmbeddings(jsonParser);
87+
return parseCohereFloatEmbeddings(jsonParser);
8088
}
8189
default -> throw new IOException("Unsupported provider [" + provider + "]");
8290
}
8391
}
8492

85-
private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XContentParser parser) throws IOException {
93+
private static List<TextEmbeddingResults.InferredValue> parseTitanFloatEmbeddings(XContentParser parser) throws IOException {
8694
/*
8795
Titan response:
8896
{
@@ -92,11 +100,11 @@ private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XC
92100
*/
93101
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
94102
List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
95-
var embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
103+
TextEmbeddingResults.InferredValue embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
96104
return List.of(embeddingValues);
97105
}
98106

99-
private static List<TextEmbeddingFloatResults.Embedding> parseCohereEmbeddings(XContentParser parser) throws IOException {
107+
private static List<TextEmbeddingResults.InferredValue> parseCohereFloatEmbeddings(XContentParser parser) throws IOException {
100108
/*
101109
Cohere response:
102110
{
@@ -111,17 +119,43 @@ private static List<TextEmbeddingFloatResults.Embedding> parseCohereEmbeddings(X
111119
*/
112120
positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
113121

114-
List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
122+
List<TextEmbeddingResults.InferredValue> embeddingList = parseList(
115123
parser,
116-
AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem
124+
AmazonBedrockEmbeddingsResponse::parseCohereFloatEmbeddingsListItem
117125
);
118126

119127
return embeddingList;
120128
}
121129

122-
private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException {
130+
private static TextEmbeddingResults.InferredValue parseCohereFloatEmbeddingsListItem(XContentParser parser) throws IOException {
123131
List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
124132
return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
125133
}
126134

135+
private static List<TextEmbeddingResults.InferredValue> parseBinaryEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider)
136+
throws IOException {
137+
switch (provider) {
138+
case AMAZONTITAN -> {
139+
return parseTitanBinaryEmbeddings(jsonParser);
140+
}
141+
default -> throw new IOException("Binary embeddings not supported for provider [" + provider + "]");
142+
}
143+
}
144+
145+
private static List<TextEmbeddingResults.InferredValue> parseTitanBinaryEmbeddings(XContentParser parser) throws IOException {
146+
/*
147+
Titan Binary response (structure assumed based on float version):
148+
{
149+
"embedding": "<base64-encoded-binary-data>",
150+
"inputTextTokenCount": int
151+
}
152+
*/
153+
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
154+
String base64Embedding = parser.text();
155+
byte[] embeddingBytes = Base64.getDecoder().decode(base64Embedding);
156+
157+
TextEmbeddingResults.InferredValue embeddingValue = TextEmbeddingBytesResults.Embedding.of(embeddingBytes);
158+
return List.of(embeddingValue);
159+
}
160+
127161
}

0 commit comments

Comments
 (0)