Skip to content

[ML] Integrate SageMaker with OpenAI Embeddings #126856

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add transportversion, javadocs
  • Loading branch information
prwhelan committed Apr 22, 2025
commit f60d95105431e8fdf1d37396b53fc77faccd3e10
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ static TransportVersion def(int id) {
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -232,6 +233,7 @@ static TransportVersion def(int id) {
public static final TransportVersion BATCHED_QUERY_EXECUTION_DELAYABLE_WRITABLE = def(9_057_0_00);
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL = def(9_058_0_00);
public static final TransportVersion COMPRESS_DELAYABLE_WRITEABLE = def(9_059_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_060_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.core.Nullable;
Expand Down Expand Up @@ -267,7 +268,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> listen

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import java.util.Optional;

/**
* This model represents all models in SageMaker. SageMaker maintains a base set of settings and configurations, and this model manages
* those. Any settings that are required for a specific model are stored in the {@link SageMakerStoredServiceSchema} and
* {@link SageMakerStoredTaskSchema}.
* Design:
* - Region is stored in ServiceSettings and is used to create the SageMaker client.
* - RateLimiting is based on AWS Service Quota, metered by account and region. The SDK client handles rate limiting internally. In order to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.sagemaker.model;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -28,6 +29,10 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;

/**
* Maintains the settings for SageMaker that cannot be changed without impacting semantic search and AI assistants.
* Model-specific settings are stored in {@link SageMakerStoredServiceSchema}.
*/
record SageMakerServiceSettings(
String endpointName,
String region,
Expand Down Expand Up @@ -75,7 +80,7 @@ public String getWriteableName() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.sagemaker.model;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -22,6 +23,9 @@

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;

/**
* Maintains mutable settings for SageMaker. Model-specific settings are stored in {@link SageMakerStoredTaskSchema}.
*/
record SageMakerTaskSettings(
@Nullable String customAttributes,
@Nullable String enableExplanations,
Expand Down Expand Up @@ -92,7 +96,7 @@ public String getWriteableName() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
import java.util.Map;
import java.util.stream.Stream;

/**
* All the logic that is required to call any SageMaker model is handled within this Schema class.
* Any model-specific logic is handled within the associated {@link SageMakerSchemaPayload}.
* This schema is specific to SageMaker's non-streaming API. For streaming, see {@link SageMakerStreamSchema}.
*/
public class SageMakerSchema {
private static final String INTERNAL_DEPENDENCY_ERROR = "Received an internal dependency error from SageMaker for [%s]";
private static final String INTERNAL_FAILURE = "Received an internal failure from SageMaker for [%s]";
Expand Down Expand Up @@ -128,10 +133,6 @@ protected Tuple<String, RestStatus> errorMessageAndStatus(SageMakerModel model,
return Tuple.tuple(errorMessage, restStatus);
}

public String api() {
return schemaPayload.api();
}

/**
 * Returns the model-specific service settings produced by the schema payload for the given map.
 * This is a pure delegate: both arguments are handed to the payload unchanged and its result is returned as-is.
 * NOTE(review): presumably problems found in the map are recorded in {@code validationException} rather than thrown; confirm
 * against the payload implementations.
 */
public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
return schemaPayload.apiServiceSettings(serviceSettings, validationException);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,39 @@
import java.util.stream.Stream;

public interface SageMakerSchemaPayload {

/**
* The model API keyword that users will supply in the service settings when creating the request.
* Automatically registered in {@link SageMakerSchemas}.
*/
String api();

/**
* The supported TaskTypes for this model API.
* Automatically registered in {@link SageMakerSchemas}.
*/
EnumSet<TaskType> supportedTasks();

/**
* Implement this if the model requires extra ServiceSettings that can be saved to the model index.
* This can be accessed via {@link SageMakerModel#apiServiceSettings()}.
*/
default SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
return SageMakerStoredServiceSchema.NO_OP;
}

/**
* Implement this if the model requires extra TaskSettings that can be saved to the model index.
* This can be accessed via {@link SageMakerModel#apiTaskSettings()}.
*/
default SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
return SageMakerStoredTaskSchema.NO_OP;
}

/**
* Creates the exception that signals this payload does not support the given model's schema.
* This must be thrown if {@link SageMakerModel#apiServiceSettings()} or {@link SageMakerModel#apiTaskSettings()} return the wrong
* object types.
*/
default Stream<NamedWriteableRegistry.Entry> namedWriteables() {
return Stream.of();
}

default Exception createUnsupportedSchemaException(SageMakerModel model) {
return new IllegalArgumentException(
Strings.format(
Expand All @@ -54,12 +68,31 @@ default Exception createUnsupportedSchemaException(SageMakerModel model) {
);
}

/**
* Automatically register the required registry entries with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
*/
default Stream<NamedWriteableRegistry.Entry> namedWriteables() {
return Stream.of();
}

/**
* The MIME type of the response from SageMaker.
*/
String accept(SageMakerModel model);

/**
* The MIME type of the request to SageMaker.
*/
String contentType(SageMakerModel model);

/**
* Translate to the body of the request in the MIME type specified by {@link #contentType(SageMakerModel)}.
*/
SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception;

/**
* Translate from the body of the response in the MIME type specified by {@link #accept(SageMakerModel)}.
*/
InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception;

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

import static org.elasticsearch.core.Strings.format;

/**
* The mapping and registry for all supported model APIs.
*/
public class SageMakerSchemas {
private static final Map<TaskAndApi, SageMakerSchema> schemas;
private static final Map<TaskAndApi, SageMakerStreamSchema> streamSchemas;
Expand All @@ -33,6 +36,9 @@ public class SageMakerSchemas {
private static final EnumSet<TaskType> supportedTaskTypes;

static {
/*
* Add each new model API to the register call.
*/
schemas = register(new OpenAiTextEmbeddingPayload());

streamSchemas = schemas.entrySet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
package org.elasticsearch.xpack.inference.services.sagemaker.schema;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;

/**
* Contains any model-specific settings that are stored in SageMakerServiceSettings.
*/
public interface SageMakerStoredServiceSchema extends ToXContentFragment, VersionedNamedWriteable {
SageMakerStoredServiceSchema NO_OP = new SageMakerStoredServiceSchema() {
private static final String NAME = "noop_sagemaker_service_schema";
Expand All @@ -24,7 +28,7 @@ public String getWriteableName() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.sagemaker.schema;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
Expand All @@ -16,6 +17,10 @@

import java.util.Map;

/**
* Contains any model-specific settings that are stored in SageMakerTaskSettings.
* Because TaskSettings are updatable, this object must be able to mutate itself, which we handle through the {@link Builder}.
*/
public interface SageMakerStoredTaskSchema extends ToXContentFragment, VersionedNamedWriteable {
SageMakerStoredTaskSchema NO_OP = new SageMakerStoredTaskSchema() {

Expand Down Expand Up @@ -44,7 +49,7 @@ public String getWriteableName() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current(); // TODO
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}

@Override
Expand All @@ -60,9 +65,17 @@ default SageMakerStoredTaskSchema update(Map<String, Object> map, ValidationExce
return toBuilder().fromMap(map, exception).build();
}

/**
* This is called during {@link #update(Map, ValidationException)}.
* Implementations should set the current field values in the Builder, as the update function is expected to overwrite them.
*/
Builder toBuilder();

interface Builder {
/**
* The map will either come from the PUT request or the stored value in the model index.
* It must match the map written by toXContent.
*/
Builder fromMap(Map<String, Object> map, ValidationException exception);

SageMakerStoredTaskSchema build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
import java.util.concurrent.Flow;
import java.util.function.BiFunction;

/**
* All the logic that is required to call any SageMaker model is handled within this Schema class.
* Any model-specific logic is handled within the associated {@link SageMakerStreamSchemaPayload}.
* This schema is specific to SageMaker's streaming API. For non-streaming, see {@link SageMakerSchema}.
*/
public class SageMakerStreamSchema extends SageMakerSchema {

private final SageMakerStreamSchemaPayload payload;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.services.sagemaker.schema;

import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;

import software.amazon.awssdk.core.SdkBytes;

import org.elasticsearch.inference.InferenceServiceResults;
Expand All @@ -16,13 +18,12 @@

import java.util.EnumSet;

/**
* Implemented for models that support streaming.
* This is an extension of {@link SageMakerSchemaPayload} because Elastic expects Completion tasks to handle both streaming and
* non-streaming, and all models currently support toggling streaming on/off.
*/
public interface SageMakerStreamSchemaPayload extends SageMakerSchemaPayload {
InferenceServiceResults.Result streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception;

SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception;

InferenceServiceResults.Result chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception;

/**
* We currently only support streaming for Completion and Chat Completion, and if we are going to implement one then we should implement
* the other, so this interface requires both streaming input and streaming unified input.
Expand All @@ -33,4 +34,14 @@ public interface SageMakerStreamSchemaPayload extends SageMakerSchemaPayload {
default EnumSet<TaskType> supportedTasks() {
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}

/**
* This API is only called for Completion task types. {@link #requestBytes(SageMakerModel, SageMakerInferenceRequest)}
* handles the request translation for both streaming and non-streaming.
*/
InferenceServiceResults.Result streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception;

SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception;

InferenceServiceResults.Result chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;

import org.elasticsearch.TransportVersions;

import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;

Expand Down Expand Up @@ -132,7 +134,7 @@ public String getWriteableName() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}

@Override
Expand Down Expand Up @@ -180,7 +182,7 @@ public String getWriteableName() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}

@Override
Expand Down
Loading