Skip to content

[ML] Custom Inference Service #125679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 51 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
4f4c603
add inference custom model
Huaixinww Mar 7, 2025
e53b2e4
add unit test
Huaixinww Mar 7, 2025
c593b1a
spotless apply
Huaixinww Mar 7, 2025
0a851c1
add custom validation
Huaixinww Mar 7, 2025
e240f06
xpack core spotless apply
Huaixinww Mar 7, 2025
3ea3053
update commons-lang3's version
Huaixinww Mar 7, 2025
83daf69
Fix compilation after rebase
davidkyle Mar 25, 2025
3cb0cfb
Add missing licences and fix build checks
davidkyle Mar 25, 2025
a3c862c
Remove some unused code
davidkyle Mar 25, 2025
2b7e6fe
Update docs/changelog/125679.yaml
davidkyle Mar 26, 2025
6cc593d
Fix services it
davidkyle Mar 27, 2025
a4630e3
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 7, 2025
95f23f0
Contuing refactor of service settings
jonathan-buttner Apr 9, 2025
014f95b
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 9, 2025
189edba
Moving classes to reflect new structure
jonathan-buttner Apr 9, 2025
4fe3a1f
Refactoring service settings
jonathan-buttner Apr 9, 2025
4ef37f5
Refactoring the request
jonathan-buttner Apr 10, 2025
6bac18b
Adding files to handle generic error response
jonathan-buttner Apr 11, 2025
f644471
Making progress on tests
jonathan-buttner Apr 15, 2025
11cf7cc
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 15, 2025
f962d74
Adding more tests
jonathan-buttner Apr 16, 2025
eb63e8b
Adding more tests
jonathan-buttner Apr 18, 2025
adc3210
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 18, 2025
c9ff298
Adding tests for remaining parsers
jonathan-buttner Apr 21, 2025
de83271
More tests
jonathan-buttner Apr 22, 2025
34df922
Need to address quoted strings
jonathan-buttner Apr 24, 2025
b496732
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 28, 2025
097246b
Adding query parameter handling and tests
jonathan-buttner Apr 28, 2025
e7f6ac5
Adding encoding tests
jonathan-buttner Apr 29, 2025
a8c5241
Fixing embedding dimensions issue and test field names
jonathan-buttner Apr 29, 2025
3df0f70
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 29, 2025
ad55337
Fixing tests
jonathan-buttner Apr 29, 2025
4714fd3
[CI] Auto commit changes from spotless
elasticsearchmachine Apr 29, 2025
d13191c
Removing licenses
jonathan-buttner Apr 29, 2025
12d46d7
Adding custom service tests
jonathan-buttner May 2, 2025
0134346
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
e6fefc4
[CI] Auto commit changes from spotless
elasticsearchmachine May 2, 2025
eef7188
Correcting tranport version number
jonathan-buttner May 2, 2025
83837c8
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 2, 2025
dc02425
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
59f75b9
Cleaning up
jonathan-buttner May 2, 2025
8a82163
Fixing counts
jonathan-buttner May 5, 2025
5f13d28
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 5, 2025
c211d83
Fixing rerank and chat completions
jonathan-buttner May 7, 2025
133ef4e
Missing a few changes
jonathan-buttner May 7, 2025
5c28ee8
Passing request to the error response handler
jonathan-buttner May 7, 2025
be84291
Merge remote-tracking branch 'origin/ml-expose-request-in-error-parse…
jonathan-buttner May 7, 2025
8d1bd22
Adding inference id to error parser log message
jonathan-buttner May 8, 2025
a0984c7
Reverting exposing request to error parsing logic
jonathan-buttner May 8, 2025
4242a37
Refactoring the error parsing logic
jonathan-buttner May 8, 2025
6492cd7
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Making progress on tests
  • Loading branch information
jonathan-buttner committed Apr 15, 2025
commit f64447121f86eaa1c3265b9751267a5b2adb0d1c
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ public static Map<String, Object> extractOptionalMap(
}

/**
* Validates that each value in the map is a {@link String} and returns a new map of Map<String, String>.
* Validates that each value in the map is a {@link String} and returns a new map of {@code Map<String, String>}.
*/
public static Map<String, String> validateMapStringValues(
Map<String, ?> map,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public void execute(
var request = new CustomRequest(query, input, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener));
} catch (Exception e) {
// Intentionally not logging this exception because it could contain sensitive information from the CustomRequest construction
listener.onFailure(
new ElasticsearchStatusException("Failed to construct the custom service request", RestStatus.BAD_REQUEST, e)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static CustomSecretSettings fromMap(@Nullable Map<String, Object> map) {
throw validationException;
}

return new CustomSecretSettings(Objects.requireNonNullElse(secureStringMap, new HashMap<>()));
return new CustomSecretSettings(secureStringMap);
}

private final Map<String, SecureString> secretParameters;
Expand All @@ -56,7 +56,7 @@ public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
}

public CustomSecretSettings(@Nullable Map<String, SecureString> secretParameters) {
this.secretParameters = Objects.requireNonNullElse(secretParameters, new HashMap<>());
this.secretParameters = Objects.requireNonNullElse(secretParameters, Map.of());
}

public CustomSecretSettings(StreamInput in) throws IOException {
Expand All @@ -71,7 +71,13 @@ public Map<String, SecureString> getSecretParameters() {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (secretParameters.isEmpty() == false) {
builder.field(SECRET_PARAMETERS, secretParameters);
builder.startObject(SECRET_PARAMETERS);
{
for (var entry : secretParameters.entrySet()) {
builder.field(entry.getKey(), entry.getValue().toString());
}
}
builder.endObject();
}
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public static CustomTaskSettings of(CustomTaskSettings originalSettings, CustomT
private final Map<String, Object> parameters;

public CustomTaskSettings(StreamInput in) throws IOException {
parameters = in.readBoolean() ? in.readGenericMap() : null;
parameters = in.readGenericMap();
}

public CustomTaskSettings(Map<String, Object> parameters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

public class CustomRequest implements Request {
/**
* This regex pattern matches on the string "${<any characters>}"
* This regex pattern matches on the string {@code "${<any characters>}"}
*/
private static final Pattern VARIABLE_PLACEHOLDER_PATTERN = Pattern.compile("\\$\\{.*?\\}");

Expand Down Expand Up @@ -71,7 +71,7 @@ private static <T> String toJson(T value, String field) {
builder.value(value);
return Strings.toString(builder);
} catch (IOException e) {
throw new IllegalStateException(Strings.format("Failed to serialize custom request value as json, field: %s"), e);
throw new IllegalStateException(Strings.format("Failed to serialize custom request value as json, field: %s", field), e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
Expand Down Expand Up @@ -46,10 +47,14 @@ public static RerankResponseParser fromMap(Map<String, Object> responseParserMap
return new RerankResponseParser(relevanceScore, rerankIndex, documentText);
}

public RerankResponseParser(String relevanceScorePath, String rerankIndexPath, String documentTextPath) {
public RerankResponseParser(String relevanceScorePath) {
this(relevanceScorePath, null, null);
}

public RerankResponseParser(String relevanceScorePath, @Nullable String rerankIndexPath, @Nullable String documentTextPath) {
this.relevanceScorePath = Objects.requireNonNull(relevanceScorePath);
this.rerankIndexPath = Objects.requireNonNull(rerankIndexPath);
this.documentTextPath = Objects.requireNonNull(documentTextPath);
this.rerankIndexPath = rerankIndexPath;
this.documentTextPath = documentTextPath;
}

public RerankResponseParser(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,8 @@ public static void assertJsonEquals(String actual, String expected) throws IOExc
assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered()));
}
}

public static <K, V> Map<K, V> modifiableMap(Map<K, V> aMap) {
return new HashMap<>(aMap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.inference.Utils.modifiableMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
Expand Down Expand Up @@ -919,8 +920,4 @@ public void testValidateInputType_ValidationErrorsWhenInputTypeIsSpecified() {
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(InputType.CLUSTERING, validationException);
assertThat(validationException.validationErrors().size(), is(4));
}

private static <K, V> Map<K, V> modifiableMap(Map<K, V> aMap) {
return new HashMap<>(aMap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@

package org.elasticsearch.xpack.inference.services.custom;

import io.netty.handler.codec.http.HttpMethod;

import org.apache.http.HttpHeaders;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;

import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.is;
Expand All @@ -28,7 +28,7 @@ public class CustomModelTests extends ESTestCase {
public static String taskSettingsValue = "test_taskSettings_value";

public static String secretSettingsKey = "test_secret_key";
public static String secretSettingsValue = "test_secret_value";
public static SecureString secretSettingsValue = new SecureString("test_secret_value".toCharArray());
public static String url = "http://www.abc.com";
public static String path = "/endpoint";

Expand Down Expand Up @@ -66,44 +66,30 @@ public static CustomModel createModel(
}

public static CustomModel getTestModel() {
TaskType taskType = TaskType.TEXT_EMBEDDING;
Map<String, Object> jsonParserMap = new HashMap<>(
Map.of(CustomServiceSettings.TEXT_EMBEDDING_PARSER_EMBEDDINGS, "$.result.embeddings[*].embedding")
);
return getTestModel(taskType, jsonParserMap);
return getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"));
}

public static CustomModel getTestModel(TaskType taskType, Map<String, Object> jsonParserMap) {
// service settings
public static CustomModel getTestModel(TaskType taskType, ResponseParser responseParser) {
Integer dims = 1536;
Integer maxInputTokens = 512;
String description = "test fromMap";
String version = "v1";
String serviceType = taskType.toString();
String method = HttpMethod.POST.name();
String queryString = "?query=${" + taskSettingsKey + "}";
Map<String, Object> headers = Map.of(HttpHeaders.AUTHORIZATION, "${" + secretSettingsKey + "}");
Map<String, String> headers = Map.of(HttpHeaders.AUTHORIZATION, "${" + secretSettingsKey + "}");
String requestContentString = "\"input\":\"${input}\"";

ResponseJsonParser responseJsonParser = new ResponseJsonParser(taskType, jsonParserMap, new ValidationException());

CustomServiceSettings serviceSettings = new CustomServiceSettings(
SimilarityMeasure.DOT_PRODUCT,
dims,
maxInputTokens,
url,
path,
method,
queryString,
headers,
null,
requestContentString,
responseJsonParser,
new RateLimitSettings(10_000)
responseParser,
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message")
);

// task settings
CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue), false);
CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue));

// secret settings
CustomSecretSettings secretSettings = new CustomSecretSettings(Map.of(secretSettingsKey, secretSettingsValue));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

package org.elasticsearch.xpack.inference.services.custom;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.hamcrest.MatcherAssert;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

import java.io.IOException;
import java.util.HashMap;
Expand All @@ -22,11 +24,14 @@
import static org.elasticsearch.core.Tuple.tuple;
import static org.hamcrest.Matchers.is;

public class CustomSecretSettingsTests extends AbstractWireSerializingTestCase<CustomSecretSettings> {
public class CustomSecretSettingsTests extends AbstractBWCWireSerializationTestCase<CustomSecretSettings> {
public static CustomSecretSettings createRandom() {
var secretParameters = randomBoolean()
? randomMap(0, 5, () -> tuple(randomAlphaOfLength(5), (Object) randomAlphaOfLength(5)))
: null;
Map<String, SecureString> secretParameters = randomMap(
0,
5,
() -> tuple(randomAlphaOfLength(5), new SecureString(randomAlphaOfLength(5).toCharArray()))
);

return new CustomSecretSettings(secretParameters);
}

Expand All @@ -35,20 +40,28 @@ public void testFromMap() {
Map.of(CustomSecretSettings.SECRET_PARAMETERS, new HashMap<>(Map.of("test_key", "test_value")))
);

MatcherAssert.assertThat(
assertThat(
CustomSecretSettings.fromMap(secretParameters),
is(new CustomSecretSettings(Map.of("test_key", "test_value")))
is(new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))))
);
}

public void testXContent() throws IOException {
var entity = new CustomSecretSettings(Map.of("test_key", "test_value"));
var entity = new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray())));

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);

assertThat(xContentResult, is("{\"secret_parameters\":{\"test_key\":\"test_value\"}}"));
var expected = XContentHelper.stripWhitespace("""
{
"secret_parameters": {
"test_key": "test_value"
}
}
""");

assertThat(xContentResult, is(expected));
}

@Override
Expand All @@ -63,6 +76,11 @@ protected CustomSecretSettings createTestInstance() {

@Override
protected CustomSecretSettings mutateInstance(CustomSecretSettings instance) {
return null;
return randomValueOtherThan(instance, CustomSecretSettingsTests::createRandom);
}

@Override
protected CustomSecretSettings mutateInstanceForVersion(CustomSecretSettings instance, TransportVersion version) {
return instance;
}
}
Loading