[ML] Integrate SageMaker with OpenAI Embeddings #126856


Merged — 21 commits, May 1, 2025
Conversation

@prwhelan (Member) commented Apr 15, 2025

Integrating with SageMaker.

Current design:

  • SageMaker accepts any byte payload, which can be text, CSV, or JSON. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, or `common`, and probably `cohere` or `huggingface` as well.
  • `api` implementations are extensions of `SageMakerSchemaPayload`, which supports:
    • "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings`
    • conversion logic from the model, service settings, task settings, and input to `SdkBytes`
    • conversion logic from the response `SdkBytes` to `InferenceServiceResults`
  • Everything else is tunneling: there are a number of base `service_settings` and `task_settings`, independent of the `api` format, that we will store and set.
  • We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
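The extension point described above might be sketched roughly as follows. Only the names `SageMakerSchemaPayload`, `SdkBytes`, and the `api`/`dimensions` settings come from this PR; the stub types, method signatures, and the `Demo` class are illustrative assumptions, not the actual implementation:

```java
import java.nio.charset.StandardCharsets;
import java.util.Map;

// Stub standing in for the AWS SDK's SdkBytes (illustrative only).
record SdkBytes(byte[] bytes) {
    static SdkBytes fromUtf8String(String s) {
        return new SdkBytes(s.getBytes(StandardCharsets.UTF_8));
    }

    String asUtf8String() {
        return new String(bytes, StandardCharsets.UTF_8);
    }
}

// Hypothetical shape of the per-`api` extension point.
interface SageMakerSchemaPayload {
    // Value of the `api` setting this payload handles, e.g. "openai".
    String api();

    // Convert input plus payload-specific settings into the request body.
    SdkBytes requestBytes(String input, Map<String, Object> serviceSettings);

    // Parse the endpoint's raw response (a String stands in for InferenceServiceResults).
    String parseResponse(SdkBytes responseBytes);
}

// Minimal illustrative implementation for an OpenAI-style embedding payload.
class OpenAiPayload implements SageMakerSchemaPayload {
    public String api() {
        return "openai";
    }

    public SdkBytes requestBytes(String input, Map<String, Object> serviceSettings) {
        // `dimensions` is the "extra" service setting this payload requires.
        Object dimensions = serviceSettings.get("dimensions");
        return SdkBytes.fromUtf8String("{\"input\":\"" + input + "\",\"dimensions\":" + dimensions + "}");
    }

    public String parseResponse(SdkBytes responseBytes) {
        return responseBytes.asUtf8String();
    }
}

class Demo {
    public static void main(String[] args) {
        SageMakerSchemaPayload payload = new OpenAiPayload();
        var body = payload.requestBytes("hello", Map.of("dimensions", 1536));
        System.out.println(body.asUtf8String()); // prints {"input":"hello","dimensions":1536}
    }
}
```

The point of the design is that the transport (SageMaker invocation, retries, rate limiting) stays generic while each `api` value maps to one payload implementation.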

@prwhelan prwhelan added >enhancement :ml Machine learning Team:ML Meta label for the ML team v9.1.0 labels Apr 15, 2025
@elasticsearchmachine (Collaborator)

Hi @prwhelan, I've created a changelog YAML for you.

@jonathan-buttner (Contributor) left a comment


Looking good! Just left a few thoughts.

return builder.endObject();
}

private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
Inline review comment (Contributor):

Nice, might be helpful to have this in a utility class somewhere eventually because we have to do stuff like this a lot.
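The helper under discussion writes a field only when the value is present. A self-contained sketch of that pattern, with a stub builder standing in for Elasticsearch's `XContentBuilder` (the stub's names and behavior are assumptions for illustration):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Stub standing in for XContentBuilder (illustrative only).
class StubBuilder {
    final Map<String, Object> fields = new LinkedHashMap<>();

    StubBuilder field(String name, Object value) {
        fields.put(name, value);
        return this;
    }
}

class OptionalFields {
    // Write the field only when a value is present, mirroring optionalField.
    static <T> void optionalField(String name, T value, StubBuilder builder) {
        if (value != null) {
            builder.field(name, value);
        }
    }

    public static void main(String[] args) {
        StubBuilder builder = new StubBuilder();
        optionalField("dimensions", 1536, builder);
        optionalField("embedding_type", null, builder);
        System.out.println(builder.fields); // prints {dimensions=1536}
    }
}
```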

@prwhelan prwhelan changed the title [Draft][Not For Checkin] Current SageMaker work [ML] Integrate SageMaker with OpenAI Embeddings Apr 25, 2025
@prwhelan prwhelan marked this pull request as ready for review April 25, 2025 17:14
@prwhelan prwhelan requested a review from a team as a code owner April 25, 2025 17:14
@elasticsearchmachine (Collaborator)

Pinging @elastic/ml-core (Team:ML)

@prwhelan prwhelan added auto-backport Automatically create backport pull requests when merged v8.19.0 labels Apr 29, 2025
@jonathan-buttner (Contributor) left a comment


Looks good! Just a reminder to add docs in the elasticsearch-specification repo.

return Collections.unmodifiableMap(configurationMap);
});
new LazyInitializable<>(
() -> configuration(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).collect(
Inline review comment (Contributor):

nit: Would Map.of() work instead of using a stream?

Follow-up comment (Contributor):

Oh I see we're combining multiple streams in a separate place 👍
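For readers following along: `Map.of()` covers a single fixed set of entries, but combining entries from several streams needs a collector. Roughly (the keys and values here are illustrative):

```java
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class MapBuildingDemo {
    public static void main(String[] args) {
        // A single fixed set of entries: Map.of is simpler.
        Map<String, Integer> fixed = Map.of("text_embedding", 1, "completion", 2);

        // Combining entries from multiple streams: collect into a map.
        Map<String, Integer> combined = Stream.concat(
            Stream.of(Map.entry("text_embedding", 1)),
            Stream.of(Map.entry("completion", 2))
        ).collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));

        System.out.println(fixed.equals(combined)); // prints true
    }
}
```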

@davidkyle davidkyle self-requested a review April 30, 2025 11:31
@davidkyle (Member) left a comment


LGTM

} else {
ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
Inline review comment (Member):

Suggested change:
-    listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
+    listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));
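The suggestion matters because wrapping an exception without its cause discards the underlying failure from anything the caller can inspect. A quick illustration:

```java
class CauseDemo {
    public static void main(String[] args) {
        Throwable t = new IllegalStateException("socket closed");

        // Without the cause, callers cannot see what actually failed.
        RuntimeException without = new RuntimeException("Unknown failure calling SageMaker.");

        // With the cause, the original error travels with the wrapper.
        RuntimeException with = new RuntimeException("Unknown failure calling SageMaker.", t);

        System.out.println(without.getCause()); // prints null
        System.out.println(with.getCause().getMessage()); // prints socket closed
    }
}
```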

public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
log.debug("Subscriber connecting to publisher.");
var publisher = holder.getAndSet(null).v1();

Map<String, Object> config,
ActionListener<Model> parsedModelListener
) {
ActionListener.completeWith(parsedModelListener, () -> modelBuilder.fromRequest(modelId, taskType, NAME, config));
Inline review comment (Member):

Nice


public class SageMakerService implements InferenceService {
public static final String NAME = "sagemaker";
private static final int DEFAULT_BATCH_SIZE = 2048;
Inline review comment (Member):

This seems like a big number. 2048 may be an optimal size for SageMaker, but a batch this size would use quite a lot of memory and isn't sympathetic to how the inference API works.

Map.entry(
API,
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The API format that your SageMaker Endpoint expects.")
.setLabel("Api")
Inline review comment (Member):

Suggested change:
-    .setLabel("Api")
+    .setLabel("API")

public final void testXContentRoundTrip() throws IOException {
var instance = createTestInstance();
var instanceAsMap = toMap(instance);
var roundTripInstance = fromMutableMap(new HashMap<>(instanceAsMap));
Inline review comment (Member):

🙇
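The round-trip test above serializes an instance to a map and rebuilds it from a mutable copy, asserting nothing is lost. The pattern, sketched with plain maps (all type and field names here are illustrative, not the actual test types):

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative settings object that can serialize to a map and be rebuilt.
record Settings(String api, int dimensions) {
    Map<String, Object> toMap() {
        Map<String, Object> map = new HashMap<>();
        map.put("api", api);
        map.put("dimensions", dimensions);
        return map;
    }

    // remove() consumes keys from a mutable map, as the test's fromMutableMap does.
    static Settings fromMutableMap(Map<String, Object> map) {
        return new Settings((String) map.remove("api"), (int) map.remove("dimensions"));
    }
}

class RoundTripDemo {
    public static void main(String[] args) {
        var instance = new Settings("openai", 1536);
        var roundTrip = Settings.fromMutableMap(new HashMap<>(instance.toMap()));
        System.out.println(instance.equals(roundTrip)); // prints true
    }
}
```

Copying into a fresh `HashMap` before parsing keeps the original serialized map intact even though parsing consumes keys.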

@prwhelan prwhelan enabled auto-merge (squash) May 1, 2025 15:54
@prwhelan prwhelan merged commit 245f5ee into elastic:main May 1, 2025
16 of 17 checks passed
@elasticsearchmachine (Collaborator)

💔 Backport failed

Branch 8.19: commit could not be cherry-picked due to conflicts.

You can use sqren/backport to manually backport by running `backport --upstream elastic/elasticsearch --pr 126856`.

prwhelan added a commit to prwhelan/elasticsearch that referenced this pull request May 1, 2025
elasticsearchmachine pushed a commit that referenced this pull request May 1, 2025
Labels: auto-backport, backport pending, >enhancement, :ml, Team:ML, v8.19.0, v9.1.0

5 participants