Skip to content

[8.19] [ML] Remove error parsing functionality for custom service (#128778) #129638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ static TransportVersion def(int id) {
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public static RateLimitGrouping of(CustomModel model) {
}
}

private static ResponseHandler createCustomHandler(CustomModel model) {
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser());
/**
 * Builds the response handler used by this request manager.
 *
 * <p>The handler is created with the request type {@code "custom model"} and uses
 * {@code CustomResponseEntity::fromResponse} as its response parse function. No model is
 * needed: the handler is constructed without any model-specific error parser.
 *
 * @return a new {@link CustomResponseHandler}
 */
private static ResponseHandler createCustomHandler() {
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse);
}

public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) {
Expand All @@ -55,7 +55,7 @@ public static CustomRequestManager of(CustomModel model, ThreadPool threadPool)
/**
 * Creates a request manager for the given custom model.
 *
 * <p>The superclass receives the thread pool, the model's inference entity id, a
 * {@link RateLimitGrouping} derived from the model, and the rate limit settings taken from
 * the model's rate limit service settings. The model is then stored and the response
 * handler is created.
 *
 * @param model      the custom model this manager sends requests for
 * @param threadPool passed through to the superclass constructor
 */
private CustomRequestManager(CustomModel model, ThreadPool threadPool) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
this.model = model;
this.handler = createCustomHandler(model);
this.handler = createCustomHandler();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,34 @@
package org.elasticsearch.xpack.inference.services.custom;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;

import java.nio.charset.StandardCharsets;
import java.util.function.Function;

/**
* Defines how to handle various response types returned from the custom integration.
*/
public class CustomResponseHandler extends BaseResponseHandler {
public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) {
super(requestType, parseFunction, errorParser);
// default for testing
/**
 * Turns an HTTP result into an {@link ErrorResponse} whose message is the entire response
 * body decoded as UTF-8. The body is not inspected or parsed; it is passed along verbatim.
 *
 * <p>If building that response throws, the function returns an {@link ErrorResponse} carrying
 * a "Failed to parse error response body" message instead, so it never propagates an
 * exception to its caller.
 */
static final Function<HttpResult, ErrorResponse> ERROR_PARSER = (httpResult) -> {
try {
// String(byte[], Charset) substitutes replacement characters for malformed input rather than throwing.
return new ErrorResponse(new String(httpResult.body(), StandardCharsets.UTF_8));
} catch (Exception e) {
// Defensive fallback for unexpected failures (e.g. a null body — TODO confirm whether body() can be null).
return new ErrorResponse(Strings.format("Failed to parse error response body: %s", e.getMessage()));
}
};

/**
 * Creates a handler for responses returned by the custom integration.
 *
 * <p>Error responses are always handled by the shared {@link #ERROR_PARSER}; callers cannot
 * supply their own error parser.
 *
 * @param requestType   passed to {@link BaseResponseHandler} as the request type
 * @param parseFunction passed to {@link BaseResponseHandler} as the response parse function
 */
public CustomResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, ERROR_PARSER);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
serviceSettings.getQueryParameters(),
serviceSettings.getRequestContentString(),
serviceSettings.getResponseJsonParser(),
serviceSettings.rateLimitSettings(),
serviceSettings.getErrorParser()
serviceSettings.rateLimitSettings()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
Expand Down Expand Up @@ -59,7 +58,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
public static final String REQUEST = "request";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";
public static final String ERROR_PARSER = "error_parser";

private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
Expand Down Expand Up @@ -100,15 +98,6 @@ public static CustomServiceSettings fromMap(

var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException);

Map<String, Object> errorParserMap = extractRequiredMap(
Objects.requireNonNullElse(responseParserMap, new HashMap<>()),
ERROR_PARSER,
RESPONSE_SCOPE,
validationException
);

var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException);

RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
Expand All @@ -117,13 +106,12 @@ public static CustomServiceSettings fromMap(
context
);

if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
if (responseParserMap == null || jsonParserMap == null) {
throw validationException;
}

throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME);
throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME);
throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
Expand All @@ -136,8 +124,7 @@ public static CustomServiceSettings fromMap(
queryParams,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
rateLimitSettings
);
}

Expand Down Expand Up @@ -209,7 +196,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
private final String requestContentString;
private final CustomResponseParser responseJsonParser;
private final RateLimitSettings rateLimitSettings;
private final ErrorResponseParser errorParser;

public CustomServiceSettings(
TextEmbeddingSettings textEmbeddingSettings,
Expand All @@ -218,8 +204,7 @@ public CustomServiceSettings(
@Nullable QueryParameters queryParameters,
String requestContentString,
CustomResponseParser responseJsonParser,
@Nullable RateLimitSettings rateLimitSettings,
ErrorResponseParser errorParser
@Nullable RateLimitSettings rateLimitSettings
) {
this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
this.url = Objects.requireNonNull(url);
Expand All @@ -228,7 +213,6 @@ public CustomServiceSettings(
this.requestContentString = Objects.requireNonNull(requestContentString);
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
this.errorParser = Objects.requireNonNull(errorParser);
}

public CustomServiceSettings(StreamInput in) throws IOException {
Expand All @@ -239,7 +223,11 @@ public CustomServiceSettings(StreamInput in) throws IOException {
requestContentString = in.readString();
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
rateLimitSettings = new RateLimitSettings(in);
errorParser = new ErrorResponseParser(in);
if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19)) {
// Read the error parsing fields for backwards compatibility
in.readString();
in.readString();
}
}

@Override
Expand Down Expand Up @@ -287,10 +275,6 @@ public CustomResponseParser getResponseJsonParser() {
return responseJsonParser;
}

public ErrorResponseParser getErrorParser() {
return errorParser;
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
Expand Down Expand Up @@ -331,7 +315,6 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
builder.startObject(RESPONSE);
{
responseJsonParser.toXContent(builder, params);
errorParser.toXContent(builder, params);
}
builder.endObject();

Expand Down Expand Up @@ -359,7 +342,11 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(requestContentString);
out.writeNamedWriteable(responseJsonParser);
rateLimitSettings.writeTo(out);
errorParser.writeTo(out);
if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19)) {
// Write empty strings for backwards compatibility for the error parsing fields
out.writeString("");
out.writeString("");
}
}

@Override
Expand All @@ -373,8 +360,7 @@ public boolean equals(Object o) {
&& Objects.equals(queryParameters, that.queryParameters)
&& Objects.equals(requestContentString, that.requestContentString)
&& Objects.equals(responseJsonParser, that.responseJsonParser)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
&& Objects.equals(errorParser, that.errorParser);
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
}

@Override
Expand All @@ -386,8 +372,7 @@ public int hashCode() {
queryParameters,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
rateLimitSettings
);
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;
Expand Down Expand Up @@ -120,8 +119,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r
QueryParameters.EMPTY,
requestContentString,
responseParser,
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message", inferenceId)
new RateLimitSettings(10_000)
);

CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.junit.After;
Expand Down Expand Up @@ -64,8 +63,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
null,
requestContentString,
new RerankResponseParser("$.result.score"),
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message", inferenceId)
new RateLimitSettings(10_000)
);

var model = CustomModelTests.createModel(
Expand Down
Loading