Skip to content

[ML] Custom service adding support for the semantic text field #129558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ static TransportVersion def(int id) {
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -304,6 +305,7 @@ static TransportVersion def(int id) {
public static final TransportVersion STATE_PARAM_GET_SNAPSHOT = def(9_100_0_00);
public static final TransportVersion PROJECT_ID_IN_SNAPSHOTS_DELETIONS_AND_REPO_CLEANUP = def(9_101_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.custom;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand Down Expand Up @@ -51,6 +52,27 @@ public CustomModel(
);
}

/**
 * Creates a {@link CustomModel} from raw configuration maps by parsing each map into its typed
 * settings object (service, task, and secret settings) and attaching the optional chunking settings.
 *
 * @param inferenceId      identifier of the inference endpoint
 * @param taskType         the task type this endpoint serves (also passed to service-settings parsing)
 * @param service          the service name stored in the model configuration
 * @param serviceSettings  raw service settings map, parsed via {@code CustomServiceSettings.fromMap}
 * @param taskSettings     raw task settings map, parsed via {@code CustomTaskSettings.fromMap}
 * @param secrets          raw secret settings map; may be null
 * @param chunkingSettings optional chunking settings; may be null
 * @param context          whether the maps come from a new request or persisted configuration
 */
public CustomModel(
String inferenceId,
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secrets,
@Nullable ChunkingSettings chunkingSettings,
ConfigurationParseContext context
) {
this(
inferenceId,
taskType,
service,
// Parsing order mirrors the parameter order; each fromMap validates its own section.
CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
CustomTaskSettings.fromMap(taskSettings),
CustomSecretSettings.fromMap(secrets),
chunkingSettings
);
}

// should only be used for testing
CustomModel(
String inferenceId,
Expand All @@ -67,6 +89,23 @@ public CustomModel(
);
}

// should only be used for testing
// Package-private convenience constructor that accepts already-parsed settings objects plus the
// optional chunking settings, wrapping them in ModelConfigurations/ModelSecrets without any map
// parsing or validation beyond what the settings objects themselves enforce.
CustomModel(
String inferenceId,
TaskType taskType,
String service,
CustomServiceSettings serviceSettings,
CustomTaskSettings taskSettings,
@Nullable CustomSecretSettings secretSettings,
@Nullable ChunkingSettings chunkingSettings
) {
this(
// serviceSettings is passed twice: once inside ModelConfigurations and once as the
// rate-limit settings argument of the delegated constructor.
new ModelConfigurations(inferenceId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
serviceSettings
);
}

protected CustomModel(CustomModel model, TaskSettings taskSettings) {
super(model, taskSettings);
rateLimitServiceSettings = model.rateLimitServiceSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
Expand All @@ -27,6 +28,8 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand All @@ -45,6 +48,7 @@
import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
Expand Down Expand Up @@ -81,12 +85,15 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

var chunkingSettings = extractChunkingSettings(config, taskType);

CustomModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
serviceSettingsMap,
chunkingSettings,
ConfigurationParseContext.REQUEST
);

Expand All @@ -100,6 +107,14 @@ public void parseRequestConfig(
}
}

/**
 * Pulls the optional chunking settings section out of the configuration map.
 * Chunking only applies to text embedding: for any other task type the config map is left
 * untouched and {@code null} is returned. For text embedding the chunking-settings entry is
 * removed from the map and parsed (a missing entry yields the builder's default).
 */
private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
    if (taskType != TaskType.TEXT_EMBEDDING) {
        return null;
    }

    var rawChunkingSettings = removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS);
    return ChunkingSettingsBuilder.fromMap(rawChunkingSettings);
}

@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
Expand All @@ -125,14 +140,16 @@ private static CustomModel createModelWithoutLoggingDeprecations(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secretSettings
@Nullable Map<String, Object> secretSettings,
@Nullable ChunkingSettings chunkingSettings
) {
return createModel(
inferenceEntityId,
taskType,
serviceSettings,
taskSettings,
secretSettings,
chunkingSettings,
ConfigurationParseContext.PERSISTENT
);
}
Expand All @@ -143,12 +160,13 @@ private static CustomModel createModel(
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secretSettings,
@Nullable ChunkingSettings chunkingSettings,
ConfigurationParseContext context
) {
if (supportedTaskTypes.contains(taskType) == false) {
throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
}
return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context);
}

@Override
Expand All @@ -162,15 +180,33 @@ public CustomModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap);
var chunkingSettings = extractChunkingSettings(config, taskType);

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
secretSettingsMap,
chunkingSettings
);
}

/**
 * Rebuilds a {@link CustomModel} from configuration persisted in cluster state. No secrets are
 * available on this path, so {@code null} is passed for the secret settings. Parsing happens with
 * {@code ConfigurationParseContext.PERSISTENT} (via {@code createModelWithoutLoggingDeprecations}).
 */
@Override
public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
    // Extraction order matters: each call removes its section from the shared config map.
    var serviceSettings = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
    var taskSettings = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
    var chunking = extractChunkingSettings(config, taskType);

    return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettings, taskSettings, null, chunking);
}

@Override
Expand Down Expand Up @@ -211,7 +247,27 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
if (model instanceof CustomModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}

var customModel = (CustomModel) model;
var overriddenModel = CustomModel.of(customModel, taskSettings);

var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME);
var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs.getInputs(),
customModel.getServiceSettings().getBatchSize(),
customModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage);
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
Expand All @@ -52,15 +53,18 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;

public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings {

public static final String NAME = "custom_service_settings";
public static final String URL = "url";
public static final String BATCH_SIZE = "batch_size";
public static final String HEADERS = "headers";
public static final String REQUEST = "request";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";

private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10;

public static CustomServiceSettings fromMap(
Map<String, Object> map,
Expand Down Expand Up @@ -106,6 +110,8 @@ public static CustomServiceSettings fromMap(
context
);

var batchSize = extractOptionalPositiveInteger(map, BATCH_SIZE, ModelConfigurations.SERVICE_SETTINGS, validationException);

if (responseParserMap == null || jsonParserMap == null) {
throw validationException;
}
Expand All @@ -124,7 +130,8 @@ public static CustomServiceSettings fromMap(
queryParams,
requestContentString,
responseJsonParser,
rateLimitSettings
rateLimitSettings,
batchSize
);
}

Expand All @@ -142,7 +149,6 @@ public record TextEmbeddingSettings(
null,
DenseVectorFieldMapper.ElementType.FLOAT
);

// This refers to settings that are not related to the text embedding task type (all the settings should be null)
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null);

Expand Down Expand Up @@ -196,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
private final String requestContentString;
private final CustomResponseParser responseJsonParser;
private final RateLimitSettings rateLimitSettings;
private final int batchSize;

public CustomServiceSettings(
TextEmbeddingSettings textEmbeddingSettings,
Expand All @@ -205,6 +212,19 @@ public CustomServiceSettings(
String requestContentString,
CustomResponseParser responseJsonParser,
@Nullable RateLimitSettings rateLimitSettings
) {
this(textEmbeddingSettings, url, headers, queryParameters, requestContentString, responseJsonParser, rateLimitSettings, null);
}

public CustomServiceSettings(
TextEmbeddingSettings textEmbeddingSettings,
String url,
@Nullable Map<String, String> headers,
@Nullable QueryParameters queryParameters,
String requestContentString,
CustomResponseParser responseJsonParser,
@Nullable RateLimitSettings rateLimitSettings,
@Nullable Integer batchSize
) {
this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
this.url = Objects.requireNonNull(url);
Expand All @@ -213,6 +233,7 @@ public CustomServiceSettings(
this.requestContentString = Objects.requireNonNull(requestContentString);
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
this.batchSize = Objects.requireNonNullElse(batchSize, DEFAULT_EMBEDDING_BATCH_SIZE);
}

public CustomServiceSettings(StreamInput in) throws IOException {
Expand All @@ -223,12 +244,20 @@ public CustomServiceSettings(StreamInput in) throws IOException {
requestContentString = in.readString();
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
rateLimitSettings = new RateLimitSettings(in);

if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING)
&& in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) {
// Read the error parsing fields for backwards compatibility
in.readString();
in.readString();
}

if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE)
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
batchSize = in.readVInt();
} else {
batchSize = DEFAULT_EMBEDDING_BATCH_SIZE;
}
}

@Override
Expand Down Expand Up @@ -276,6 +305,10 @@ public CustomResponseParser getResponseJsonParser() {
return responseJsonParser;
}

// Number of inputs grouped into a single embedding request when chunked inference batches them;
// falls back to DEFAULT_EMBEDDING_BATCH_SIZE when no batch_size was configured (see constructor).
public int getBatchSize() {
return batchSize;
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
Expand Down Expand Up @@ -321,6 +354,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder

rateLimitSettings.toXContent(builder, params);

builder.field(BATCH_SIZE, batchSize);

return builder;
}

Expand All @@ -343,12 +378,18 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(requestContentString);
out.writeNamedWriteable(responseJsonParser);
rateLimitSettings.writeTo(out);

if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING)
&& out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) {
// Write empty strings for backwards compatibility for the error parsing fields
out.writeString("");
out.writeString("");
}

if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE)
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
out.writeVInt(batchSize);
}
}

@Override
Expand All @@ -362,7 +403,8 @@ public boolean equals(Object o) {
&& Objects.equals(queryParameters, that.queryParameters)
&& Objects.equals(requestContentString, that.requestContentString)
&& Objects.equals(responseJsonParser, that.responseJsonParser)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
&& Objects.equals(batchSize, that.batchSize);
}

@Override
Expand All @@ -374,7 +416,8 @@ public int hashCode() {
queryParameters,
requestContentString,
responseJsonParser,
rateLimitSettings
rateLimitSettings,
batchSize
);
}

Expand Down
Loading