16
16
import org .elasticsearch .xcontent .XContentParser ;
17
17
import org .elasticsearch .xcontent .XContentParserConfiguration ;
18
18
import org .elasticsearch .xcontent .XContentType ;
19
- import org .elasticsearch .xpack .core .inference .results .TextEmbeddingFloatResults ;
19
+ import org .elasticsearch .xpack .core .inference .results .TextEmbeddingResults ;
20
+ import org .elasticsearch .xpack .core .inference .results .TextEmbeddingBytesResults ;
20
21
import org .elasticsearch .xpack .inference .external .response .XContentUtils ;
21
22
import org .elasticsearch .xpack .inference .services .amazonbedrock .AmazonBedrockProvider ;
22
23
import org .elasticsearch .xpack .inference .services .amazonbedrock .request .AmazonBedrockRequest ;
23
24
import org .elasticsearch .xpack .inference .services .amazonbedrock .request .embeddings .AmazonBedrockEmbeddingsRequest ;
24
25
import org .elasticsearch .xpack .inference .services .amazonbedrock .response .AmazonBedrockResponse ;
26
+ import org .elasticsearch .xpack .inference .services .amazonbedrock .AmazonBedrockServiceSettings .AmazonBedrockEmbeddingType ;
25
27
26
28
import java .io .IOException ;
27
29
import java .nio .charset .StandardCharsets ;
28
30
import java .util .List ;
31
+ import java .util .Base64 ;
29
32
30
33
import static org .elasticsearch .common .xcontent .XContentParserUtils .ensureExpectedToken ;
31
34
import static org .elasticsearch .common .xcontent .XContentParserUtils .parseList ;
@@ -42,13 +45,13 @@ public AmazonBedrockEmbeddingsResponse(InvokeModelResponse invokeModelResult) {
42
45
@ Override
43
46
public InferenceServiceResults accept (AmazonBedrockRequest request ) {
44
47
if (request instanceof AmazonBedrockEmbeddingsRequest asEmbeddingsRequest ) {
45
- return fromResponse (result , asEmbeddingsRequest . provider () );
48
+ return fromResponse (result , asEmbeddingsRequest );
46
49
}
47
50
48
51
throw new ElasticsearchException ("unexpected request type [" + request .getClass () + "]" );
49
52
}
50
53
51
- public static TextEmbeddingFloatResults fromResponse (InvokeModelResponse response , AmazonBedrockProvider provider ) {
54
+ public static TextEmbeddingResults fromResponse (InvokeModelResponse response , AmazonBedrockEmbeddingsRequest request ) {
52
55
var charset = StandardCharsets .UTF_8 ;
53
56
var bodyText = String .valueOf (charset .decode (response .body ().asByteBuffer ()));
54
57
@@ -61,28 +64,33 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons
61
64
XContentParser .Token token = jsonParser .currentToken ();
62
65
ensureExpectedToken (XContentParser .Token .START_OBJECT , token , jsonParser );
63
66
64
- var embeddingList = parseEmbeddings (jsonParser , provider );
65
-
66
- return new TextEmbeddingFloatResults (embeddingList );
67
+ var embeddingType = request .getServiceSettings ().embeddingType ();
68
+ if (embeddingType == AmazonBedrockEmbeddingType .BINARY ) {
69
+ var embeddingList = parseBinaryEmbeddings (jsonParser , request .provider ());
70
+ return new TextEmbeddingBytesResults (embeddingList );
71
+ } else {
72
+ var embeddingList = parseFloatEmbeddings (jsonParser , request .provider ());
73
+ return new TextEmbeddingFloatResults (embeddingList );
74
+ }
67
75
} catch (IOException e ) {
68
76
throw new ElasticsearchException (e );
69
77
}
70
78
}
71
79
72
- private static List <TextEmbeddingFloatResults . Embedding > parseEmbeddings (XContentParser jsonParser , AmazonBedrockProvider provider )
80
+ private static List <TextEmbeddingResults . InferredValue > parseFloatEmbeddings (XContentParser jsonParser , AmazonBedrockProvider provider )
73
81
throws IOException {
74
82
switch (provider ) {
75
83
case AMAZONTITAN -> {
76
- return parseTitanEmbeddings (jsonParser );
84
+ return parseTitanFloatEmbeddings (jsonParser );
77
85
}
78
86
case COHERE -> {
79
- return parseCohereEmbeddings (jsonParser );
87
+ return parseCohereFloatEmbeddings (jsonParser );
80
88
}
81
89
default -> throw new IOException ("Unsupported provider [" + provider + "]" );
82
90
}
83
91
}
84
92
85
- private static List <TextEmbeddingFloatResults . Embedding > parseTitanEmbeddings (XContentParser parser ) throws IOException {
93
+ private static List <TextEmbeddingResults . InferredValue > parseTitanFloatEmbeddings (XContentParser parser ) throws IOException {
86
94
/*
87
95
Titan response:
88
96
{
@@ -92,11 +100,11 @@ private static List<TextEmbeddingFloatResults.Embedding> parseTitanEmbeddings(XC
92
100
*/
93
101
positionParserAtTokenAfterField (parser , "embedding" , FAILED_TO_FIND_FIELD_TEMPLATE );
94
102
List <Float > embeddingValuesList = parseList (parser , XContentUtils ::parseFloat );
95
- var embeddingValues = TextEmbeddingFloatResults .Embedding .of (embeddingValuesList );
103
+ TextEmbeddingResults . InferredValue embeddingValues = TextEmbeddingFloatResults .Embedding .of (embeddingValuesList );
96
104
return List .of (embeddingValues );
97
105
}
98
106
99
- private static List <TextEmbeddingFloatResults . Embedding > parseCohereEmbeddings (XContentParser parser ) throws IOException {
107
+ private static List <TextEmbeddingResults . InferredValue > parseCohereFloatEmbeddings (XContentParser parser ) throws IOException {
100
108
/*
101
109
Cohere response:
102
110
{
@@ -111,17 +119,43 @@ private static List<TextEmbeddingFloatResults.Embedding> parseCohereEmbeddings(X
111
119
*/
112
120
positionParserAtTokenAfterField (parser , "embeddings" , FAILED_TO_FIND_FIELD_TEMPLATE );
113
121
114
- List <TextEmbeddingFloatResults . Embedding > embeddingList = parseList (
122
+ List <TextEmbeddingResults . InferredValue > embeddingList = parseList (
115
123
parser ,
116
- AmazonBedrockEmbeddingsResponse ::parseCohereEmbeddingsListItem
124
+ AmazonBedrockEmbeddingsResponse ::parseCohereFloatEmbeddingsListItem
117
125
);
118
126
119
127
return embeddingList ;
120
128
}
121
129
122
- private static TextEmbeddingFloatResults . Embedding parseCohereEmbeddingsListItem (XContentParser parser ) throws IOException {
130
+ private static TextEmbeddingResults . InferredValue parseCohereFloatEmbeddingsListItem (XContentParser parser ) throws IOException {
123
131
List <Float > embeddingValuesList = parseList (parser , XContentUtils ::parseFloat );
124
132
return TextEmbeddingFloatResults .Embedding .of (embeddingValuesList );
125
133
}
126
134
135
+ private static List <TextEmbeddingResults .InferredValue > parseBinaryEmbeddings (XContentParser jsonParser , AmazonBedrockProvider provider )
136
+ throws IOException {
137
+ switch (provider ) {
138
+ case AMAZONTITAN -> {
139
+ return parseTitanBinaryEmbeddings (jsonParser );
140
+ }
141
+ default -> throw new IOException ("Binary embeddings not supported for provider [" + provider + "]" );
142
+ }
143
+ }
144
+
145
+ private static List <TextEmbeddingResults .InferredValue > parseTitanBinaryEmbeddings (XContentParser parser ) throws IOException {
146
+ /*
147
+ Titan Binary response (structure assumed based on float version):
148
+ {
149
+ "embedding": "<base64-encoded-binary-data>",
150
+ "inputTextTokenCount": int
151
+ }
152
+ */
153
+ positionParserAtTokenAfterField (parser , "embedding" , FAILED_TO_FIND_FIELD_TEMPLATE );
154
+ String base64Embedding = parser .text ();
155
+ byte [] embeddingBytes = Base64 .getDecoder ().decode (base64Embedding );
156
+
157
+ TextEmbeddingResults .InferredValue embeddingValue = TextEmbeddingBytesResults .Embedding .of (embeddingBytes );
158
+ return List .of (embeddingValue );
159
+ }
160
+
127
161
}
0 commit comments