Skip to content

[ML] Integrate SageMaker with OpenAI Embeddings #126856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126856.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126856
summary: "[Draft][Not For Checkin] Current `SageMaker` work"
area: Machine Learning
type: enhancement
issues: []
5 changes: 5 additions & 0 deletions gradle/verification-metadata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4902,6 +4902,11 @@
<sha256 value="da37cb021156b6aae5a30337e270a33a43817a64c59ca7aa4c39074cfda39a4b" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
<artifact name="sagemakerruntime-2.30.38.jar">
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
<artifact name="sdk-core-2.30.38.jar">
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ public final List<String> validationErrors() {
return validationErrors;
}

public final void throwIfValidationErrorsExist() {
if (validationErrors().isEmpty() == false) {
throw this;
}
}

@Override
public final String getMessage() {
StringBuilder sb = new StringBuilder();
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ dependencies {

/* AWS SDK v2 */
implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
implementation ("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
requires org.reactivestreams;
requires org.elasticsearch.logging;
requires org.elasticsearch.sslconfig;
requires software.amazon.awssdk.services.sagemakerruntime;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
Expand Down Expand Up @@ -157,6 +159,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {

namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
namedWriteables.addAll(SageMakerModel.namedWriteables());
namedWriteables.addAll(SageMakerSchemas.namedWriteables());

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

Expand Down Expand Up @@ -293,6 +297,7 @@ public Collection<?> createComponents(PluginServices services) {
services.threadPool()
);

var sageMakerSchemas = new SageMakerSchemas();
inferenceServices.add(
() -> List.of(
context -> new ElasticInferenceService(
Expand All @@ -301,6 +306,15 @@ public Collection<?> createComponents(PluginServices services) {
inferenceServiceSettings,
modelRegistry,
authorizationHandler
),
context -> new SageMakerService(
new SageMakerModelBuilder(sageMakerSchemas),
new SageMakerClient(
new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())),
services.threadPool()
),
sageMakerSchemas,
services.threadPool()
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.TimeValue;

import java.time.Duration;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -55,6 +56,10 @@ public int connectionTimeout() {
return connectionTimeout;
}

public Duration connectionTimeoutDuration() {
return Duration.ofMillis(connectionTimeout);
}

private void setMaxResponseSize(ByteSizeValue maxResponseSize) {
this.maxResponseSize = maxResponseSize;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.sagemaker;

import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.profiles.ProfileFile;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
import org.reactivestreams.FlowAdapters;

import java.io.Closeable;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Flow;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

public class SageMakerClient implements Closeable {
private static final Logger log = LogManager.getLogger(SageMakerClient.class);
private final Cache<RegionAndSecrets, SageMakerRuntimeAsyncClient> existingClients = CacheBuilder.<
RegionAndSecrets,
SageMakerRuntimeAsyncClient>builder()
.removalListener(removal -> removal.getValue().close())
.setExpireAfterAccess(TimeValue.timeValueMinutes(15))
.build();

private final CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
private final ThreadPool threadPool;

public SageMakerClient(CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory, ThreadPool threadPool) {
this.clientFactory = clientFactory;
this.threadPool = threadPool;
}

public void invoke(
RegionAndSecrets regionAndSecrets,
InvokeEndpointRequest request,
TimeValue timeout,
ActionListener<InvokeEndpointResponse> listener
) {
var asyncClient = getOrCreateClient(regionAndSecrets);
asyncClient.invokeEndpoint(request)
.orTimeout(timeout.seconds(), TimeUnit.SECONDS)
.thenAcceptAsync(listener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME))
.exceptionallyAsync(t -> failAndMaybeThrowError(t, listener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
}

private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
if (t instanceof Exception e) {
listener.onFailure(e);
} else {
ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));

}
return null; // Void
}

public void invokeStream(
RegionAndSecrets regionAndSecrets,
InvokeEndpointWithResponseStreamRequest request,
TimeValue timeout,
ActionListener<SageMakerStream> listener
) {
var asyncClient = getOrCreateClient(regionAndSecrets);
var runOnceListener = ActionListener.notifyOnce(listener);
var responseStreamProcessor = new SageMakerStreamingResponseProcessor();
var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
.onResponse(response -> runOnceListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
.onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
.build();
asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener)
.orTimeout(timeout.seconds(), TimeUnit.SECONDS)
.exceptionallyAsync(t -> failAndMaybeThrowError(t, runOnceListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
}

private SageMakerRuntimeAsyncClient getOrCreateClient(RegionAndSecrets regionAndSecrets) {
try {
return existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
} catch (ExecutionException e) {
throw new ElasticsearchException("failed to create SageMakerRuntime client", e);
}
}

@Override
public void close() {
existingClients.invalidateAll(); // will close each cached client
}

public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {}

public static class Factory implements CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> {
private final HttpSettings httpSettings;

public Factory(HttpSettings httpSettings) {
this.httpSettings = httpSettings;
}

@Override
public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
SpecialPermission.check();
// TODO migrate to entitlements
return AccessController.doPrivileged((PrivilegedExceptionAction<SageMakerRuntimeAsyncClient>) () -> {
try (var accessKey = key.secretSettings().accessKey(); var secretKey = key.secretSettings().secretKey()) {
var credentials = AwsBasicCredentials.create(accessKey.toString(), secretKey.toString());
var credentialsProvider = StaticCredentialsProvider.create(credentials);
var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
var override = ClientOverrideConfiguration.builder()
// disable profileFile, user credentials will always come from the configured Model Secrets
.defaultProfileFileSupplier(ProfileFile.aggregator()::build)
.defaultProfileFile(ProfileFile.aggregator().build())
.retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
.retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
.build();
return SageMakerRuntimeAsyncClient.builder()
.credentialsProvider(credentialsProvider)
.region(Region.of(key.region()))
.httpClientBuilder(clientConfig)
.overrideConfiguration(override)
.build();
}
});
}
}

private static class SageMakerStreamingResponseProcessor implements Flow.Publisher<ResponseStream> {
private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class);
private final AtomicReference<Tuple<Flow.Publisher<ResponseStream>, Flow.Subscriber<? super ResponseStream>>> holder =
new AtomicReference<>(null);

@Override
public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
log.debug("Subscriber connecting to publisher.");
var publisher = holder.getAndSet(null).v1();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

publisher.subscribe(subscriber);
} else {
log.debug("Subscriber waiting for connection.");
}
}

private void setPublisher(Flow.Publisher<ResponseStream> publisher) {
if (holder.compareAndSet(null, Tuple.tuple(publisher, null)) == false) {
log.debug("Publisher connecting to subscriber.");
var subscriber = holder.getAndSet(null).v2();
publisher.subscribe(subscriber);
} else {
log.debug("Publisher waiting for connection.");
}
}
}

public record SageMakerStream(InvokeEndpointWithResponseStreamResponse response, Flow.Publisher<ResponseStream> responseStream) {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.sagemaker;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;

import java.util.List;

public record SageMakerInferenceRequest(
String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
@Nullable List<String> input,
boolean stream,
InputType inputType
) {}
Loading
Loading