[ML] Integrate SageMaker with OpenAI Embeddings #126856
Conversation
Hi @prwhelan, I've created a changelog YAML for you.
Looking good! Just left a few thoughts.
    return builder.endObject();
}

private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
Nice, might be helpful to have this in a utility class somewhere eventually because we have to do stuff like this a lot.
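A minimal sketch of what such a shared helper could look like; the class name and placement are assumptions, not an existing utility:

```java
import java.io.IOException;

import org.elasticsearch.xcontent.XContentBuilder;

// Hypothetical shared utility for the "only write the field when the value is
// present" pattern discussed above; name and location are illustrative only.
public final class OptionalXContentFields {

    private OptionalXContentFields() {}

    // Mirrors the private optionalField helper in the diff: emit "name": value
    // only when the value is non-null.
    public static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
        if (value != null) {
            builder.field(name, value);
        }
    }
}
```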
Pinging @elastic/ml-core (Team:ML)
Looks good! Just a reminder to add docs in the elasticsearch-specification repo.
    return Collections.unmodifiableMap(configurationMap);
});
new LazyInitializable<>(
    () -> configuration(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).collect(
nit: Would `Map.of()` work instead of using a stream?
Oh I see we're combining multiple streams in a separate place 👍
LGTM
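For context on the exchange above, an illustrative comparison (the names are made up, not the actual service code): `Map.of()` fits a fixed set of entries, while collecting a stream of `Map.Entry` is the natural fit when entries from several sources are concatenated first.

```java
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class ConfigurationMapSketch {

    // A fixed set of entries: Map.of() is the simplest option.
    static Map<String, String> fixedEntries() {
        return Map.of("api", "The API format", "region", "The AWS region");
    }

    // Entries assembled from multiple sources (e.g. base settings plus
    // per-API settings): concatenating streams and collecting keeps it in
    // one expression, which Map.of() cannot express.
    static Map<String, String> combinedEntries(
        Stream<Map.Entry<String, String>> baseSettings,
        Stream<Map.Entry<String, String>> apiSettings
    ) {
        return Stream.concat(baseSettings, apiSettings)
            .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
    }
}
```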
} else {
    ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
    log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
    listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
Suggested change:
- listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
+ listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));
public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
    if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
        log.debug("Subscriber connecting to publisher.");
        var publisher = holder.getAndSet(null).v1();
Other implementations of this method call `onError()` if a subscriber is already set; should this do the same?
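For reference, a minimal sketch (not the actual ResponseStream publisher) of the convention the question refers to: a single-use `Flow.Publisher` that rejects a second subscriber with `onError` rather than silently ignoring it.

```java
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;

// Illustrative only: a single-subscriber publisher that signals onError when a
// second subscriber arrives, matching the Flow/Reactive Streams convention.
class SingleSubscriberPublisher<T> implements Flow.Publisher<T> {

    private final AtomicReference<Flow.Subscriber<? super T>> subscriber = new AtomicReference<>();

    @Override
    public void subscribe(Flow.Subscriber<? super T> newSubscriber) {
        if (subscriber.compareAndSet(null, newSubscriber) == false) {
            // The rejected subscriber still receives onSubscribe before the
            // terminal onError signal, per the Flow contract.
            newSubscriber.onSubscribe(new Flow.Subscription() {
                @Override
                public void request(long n) {}

                @Override
                public void cancel() {}
            });
            newSubscriber.onError(new IllegalStateException("only one subscriber is supported"));
            return;
        }
        // ... continue wiring up the accepted subscriber here
    }
}
```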
    Map<String, Object> config,
    ActionListener<Model> parsedModelListener
) {
    ActionListener.completeWith(parsedModelListener, () -> modelBuilder.fromRequest(modelId, taskType, NAME, config));
Nice
public class SageMakerService implements InferenceService {
    public static final String NAME = "sagemaker";
    private static final int DEFAULT_BATCH_SIZE = 2048;
Seems like a big number. 2048 may be an optimal size for SageMaker, but a batch this size would use quite a lot of memory and isn't sympathetic to how the inference API works.
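To illustrate the concern (this is not the SageMakerService code): whatever the batching looks like internally, a default of 2048 means up to that many inputs are buffered and serialized into a single endpoint invocation.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative sketch only. It shows what a default batch size typically
// controls: how many inputs are grouped into one request, and therefore how
// much input text is held in memory per invocation.
class BatchingSketch {

    private static final int DEFAULT_BATCH_SIZE = 2048; // the value under discussion

    static List<List<String>> partition(List<String> inputs) {
        List<List<String>> batches = new ArrayList<>();
        for (int from = 0; from < inputs.size(); from += DEFAULT_BATCH_SIZE) {
            int to = Math.min(from + DEFAULT_BATCH_SIZE, inputs.size());
            // Each sublist becomes one request payload, so a full batch of
            // 2048 documents is buffered and serialized at once.
            batches.add(inputs.subList(from, to));
        }
        return batches;
    }
}
```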
Map.entry(
    API,
    new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The API format that your SageMaker Endpoint expects.")
        .setLabel("Api")
.setLabel("Api") | |
.setLabel("API") |
public final void testXContentRoundTrip() throws IOException {
    var instance = createTestInstance();
    var instanceAsMap = toMap(instance);
    var roundTripInstance = fromMutableMap(new HashMap<>(instanceAsMap));
🙇
💔 Backport failed
You can use sqren/backport to manually backport by running:
Integrating with SageMaker.

Current design:
- SageMaker accepts any byte payload, which can be text, CSV, or JSON. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, `common`, and probably `cohere` or `huggingface` as well.
- `api` implementations are extensions of `SageMakerSchemaPayload` (see the sketch after this list), which supports:
  - "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings`
  - conversion logic from the model, service settings, task settings, and input to `SdkBytes`
  - conversion logic from the response `SdkBytes` to `InferenceServiceResults`
- Everything else is tunneling: there are a number of base `service_settings` and `task_settings`, independent of the API format, that we will store and set.
- We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
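A rough sketch of the payload abstraction described above. The method names and placeholder types are assumptions for illustration; the actual `SageMakerSchemaPayload` interface in this PR may differ.

```java
import software.amazon.awssdk.core.SdkBytes;

// Hypothetical sketch of the per-api payload extension point described above;
// names and placeholder types are illustrative, not the PR's actual interface.
interface SageMakerSchemaPayloadSketch {

    // Placeholder types standing in for the real model/settings/result classes.
    interface ModelAndSettings {}   // model + base and "extra" service/task settings
    interface InferenceInput {}     // the caller's input, e.g. the texts to embed
    interface InferenceResults {}   // stands in for InferenceServiceResults

    // Which `api` value this payload implements, e.g. "openai" or "elastic".
    String api();

    // Conversion logic from model, settings, and input to the raw bytes sent
    // to the SageMaker endpoint.
    SdkBytes requestBytes(ModelAndSettings model, InferenceInput input);

    // Conversion logic from the endpoint's response bytes back to results.
    InferenceResults responseBody(ModelAndSettings model, SdkBytes response);
}
```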