[ML] Custom Inference Service #125679

Closed
wants to merge 51 commits into from
51 commits
4f4c603
add inference custom model
Huaixinww Mar 7, 2025
e53b2e4
add unit test
Huaixinww Mar 7, 2025
c593b1a
spotless apply
Huaixinww Mar 7, 2025
0a851c1
add custom validation
Huaixinww Mar 7, 2025
e240f06
xpack core spotless apply
Huaixinww Mar 7, 2025
3ea3053
update commons-lang3's version
Huaixinww Mar 7, 2025
83daf69
Fix compilation after rebase
davidkyle Mar 25, 2025
3cb0cfb
Add missing licences and fix build checks
davidkyle Mar 25, 2025
a3c862c
Remove some unused code
davidkyle Mar 25, 2025
2b7e6fe
Update docs/changelog/125679.yaml
davidkyle Mar 26, 2025
6cc593d
Fix services it
davidkyle Mar 27, 2025
a4630e3
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 7, 2025
95f23f0
Continuing refactor of service settings
jonathan-buttner Apr 9, 2025
014f95b
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 9, 2025
189edba
Moving classes to reflect new structure
jonathan-buttner Apr 9, 2025
4fe3a1f
Refactoring service settings
jonathan-buttner Apr 9, 2025
4ef37f5
Refactoring the request
jonathan-buttner Apr 10, 2025
6bac18b
Adding files to handle generic error response
jonathan-buttner Apr 11, 2025
f644471
Making progress on tests
jonathan-buttner Apr 15, 2025
11cf7cc
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 15, 2025
f962d74
Adding more tests
jonathan-buttner Apr 16, 2025
eb63e8b
Adding more tests
jonathan-buttner Apr 18, 2025
adc3210
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 18, 2025
c9ff298
Adding tests for remaining parsers
jonathan-buttner Apr 21, 2025
de83271
More tests
jonathan-buttner Apr 22, 2025
34df922
Need to address quoted strings
jonathan-buttner Apr 24, 2025
b496732
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 28, 2025
097246b
Adding query parameter handling and tests
jonathan-buttner Apr 28, 2025
e7f6ac5
Adding encoding tests
jonathan-buttner Apr 29, 2025
a8c5241
Fixing embedding dimensions issue and test field names
jonathan-buttner Apr 29, 2025
3df0f70
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 29, 2025
ad55337
Fixing tests
jonathan-buttner Apr 29, 2025
4714fd3
[CI] Auto commit changes from spotless
elasticsearchmachine Apr 29, 2025
d13191c
Removing licenses
jonathan-buttner Apr 29, 2025
12d46d7
Adding custom service tests
jonathan-buttner May 2, 2025
0134346
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
e6fefc4
[CI] Auto commit changes from spotless
elasticsearchmachine May 2, 2025
eef7188
Correcting transport version number
jonathan-buttner May 2, 2025
83837c8
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 2, 2025
dc02425
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
59f75b9
Cleaning up
jonathan-buttner May 2, 2025
8a82163
Fixing counts
jonathan-buttner May 5, 2025
5f13d28
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 5, 2025
c211d83
Fixing rerank and chat completions
jonathan-buttner May 7, 2025
133ef4e
Missing a few changes
jonathan-buttner May 7, 2025
5c28ee8
Passing request to the error response handler
jonathan-buttner May 7, 2025
be84291
Merge remote-tracking branch 'origin/ml-expose-request-in-error-parse…
jonathan-buttner May 7, 2025
8d1bd22
Adding inference id to error parser log message
jonathan-buttner May 8, 2025
a0984c7
Reverting exposing request to error parsing logic
jonathan-buttner May 8, 2025
4242a37
Refactoring the error parsing logic
jonathan-buttner May 8, 2025
6492cd7
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 8, 2025
Refactoring the request
jonathan-buttner committed Apr 10, 2025
commit 4ef37f5e9f7cd9a786e4cb302ae8bd09c17c0ce1
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@
 import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
 import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
 import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
 import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
 import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
 import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
@@ -153,7 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         addAlibabaCloudSearchNamedWriteables(namedWriteables);
         addJinaAINamedWriteables(namedWriteables);
         addVoyageAINamedWriteables(namedWriteables);
-        addCustomWriteables(namedWriteables);
+        addCustomNamedWriteables(namedWriteables);

         addUnifiedNamedWriteables(namedWriteables);

@@ -163,6 +164,32 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         return namedWriteables;
     }

+    private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                CustomServiceSettings.NAME,
+                CustomServiceSettings::new
+            )
+        );
+
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                TaskSettings.class,
+                CustomTaskSettings.NAME,
+                CustomTaskSettings::new
+            )
+        );
+
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                SecretSettings.class,
+                CustomSecretSettings.NAME,
+                CustomSecretSettings::new
+            )
+        );
+    }
+
     private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
         var writeables = UnifiedCompletionRequest.getNamedWriteables();
         namedWriteables.addAll(writeables);
@@ -665,11 +692,4 @@ private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> nam
             )
         );
     }
-
-    private static void addCustomWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
-        namedWriteables.add(
-            new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new)
-        );
-        namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new));
-    }
 }
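The registration hunks above replace addCustomWriteables with addCustomNamedWriteables and add an entry for SecretSettings that the old method did not have. As a rough illustration of why each settings category needs its own entry, the sketch below models a lookup keyed by category and name using plain Java collections. It does not use the Elasticsearch NamedWriteableRegistry API: TinyRegistry, DemoSecretSettings, and DemoCustomSecrets are hypothetical stand-ins, and only the name custom_secret_settings is taken from the diff.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-ins for a settings category and one implementation of it.
interface DemoSecretSettings {}

final class DemoCustomSecrets implements DemoSecretSettings {
    final String payload;

    DemoCustomSecrets(String payload) {
        this.payload = payload;
    }
}

// Hypothetical registry: a reader is found by the pair (category class, entry name).
final class TinyRegistry {
    private final Map<String, Function<String, Object>> readers = new HashMap<>();

    private static String key(Class<?> category, String name) {
        return category.getName() + "/" + name;
    }

    <T> void register(Class<T> category, String name, Function<String, ? extends T> reader) {
        // Refuse to silently overwrite an existing entry for the same category and name.
        if (readers.putIfAbsent(key(category, name), reader::apply) != null) {
            throw new IllegalArgumentException("duplicate entry [" + name + "]");
        }
    }

    <T> T read(Class<T> category, String name, String payload) {
        var reader = readers.get(key(category, name));
        if (reader == null) {
            // This is the failure mode a missing registration would produce in this sketch.
            throw new IllegalArgumentException("unknown writeable [" + name + "] for " + category.getSimpleName());
        }
        return category.cast(reader.apply(payload));
    }
}
```

In this sketch, a read for a name that was never registered fails at lookup time, which is the kind of gap the added SecretSettings entry appears to close for the custom service.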
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.Function;
 import java.util.stream.Collectors;

 import static org.elasticsearch.core.Strings.format;
@@ -479,29 +480,86 @@ public static Map<String, Object> extractOptionalMap(
     }

     /**
-     * Ensures that each value in the map is a {@link String}.
-     * @param map a Map to iterate over
-     * @param settingName the setting name that this map corresponds to
-     * @param validationException aggregated validation exceptions
+     * Ensures the values of the map match one of the supplied types.
+     * @param map Map to validate
+     * @param allowedTypes List of {@link Class} to accept
+     * @param settingName the setting name for the field
+     * @param validationException exception that is thrown if one of the values is invalid
+     * @param censorValue if true the key and value will be omitted from the exception message
      */
-    public static void validateMapValueStrings(
+    public static void validateMapValues(
         Map<String, Object> map,
+        List<Class<?>> allowedTypes,
         String settingName,
-        ValidationException validationException
+        ValidationException validationException,
+        boolean censorValue
     ) {
         if (map == null) {
             return;
         }

         for (var entry : map.entrySet()) {
-            var value = entry.getValue();
-            if (value instanceof String == false) {
-                validationException.addValidationError(ServiceUtils.invalidTypeErrorMsg(settingName, value, String.class.getSimpleName()));
+            boolean isAllowed = false;
+
+            for (Class<?> allowedType : allowedTypes) {
+                if (allowedType.isInstance(entry.getValue())) {
+                    isAllowed = true;
+                    break;
+                }
+            }
+
+            Function<String[], String> errorMessage = (String[] validTypesAsStrings) -> {
+                if (censorValue) {
+                    return Strings.format(
+                        "Map field [%s] has an entry that is not valid. Value type is not one of [%s].",
+                        settingName,
+                        String.join(", ", validTypesAsStrings)
+                    );
+                } else {
+                    return Strings.format(
+                        "Map field [%s] has an entry that is not valid, [%s => %s]. Value type is not one of [%s].",
+                        settingName,
+                        entry.getKey(),
+                        entry.getValue(),
+                        String.join(", ", validTypesAsStrings)
+                    );
+                }
+            };
+
+            if (isAllowed == false) {
+                var validTypesAsStrings = allowedTypes.stream().map(Class::toString).toArray(String[]::new);
+                Arrays.sort(validTypesAsStrings);
+
+                validationException.addValidationError(errorMessage.apply(validTypesAsStrings));
                 throw validationException;
             }
         }
     }

+    public static void convertMapStringsToSecureString(Map<String, Object> map) {
+        if (map == null) {
+            return;
+        }
+
+        for (var entry : map.entrySet()) {
+            var value = entry.getValue();
+            if (value instanceof String) {
+                map.put(entry.getKey(), new SecureString(((String) value).toCharArray()));
+            }
+        }
+    }
+
+    /**
+     * Removes null values.
+     */
+    public static void removeNullValues(Map<String, Object> map) {
+        if (map == null) {
+            return;
+        }
+
+        map.values().removeIf(Objects::isNull);
+    }
+
     public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
         Map<String, Object> map,
         String settingName,
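The validateMapValues change above checks each map value against a list of allowed types and, depending on censorValue, either includes or leaves out the offending key and value in the error message. A minimal standalone sketch of that idea follows. It is not the ServiceUtils implementation: MapValueValidationSketch is a hypothetical class, errors are collected into a plain list instead of a ValidationException, every invalid entry is reported rather than stopping at the first, and type names are printed with Class.getSimpleName rather than Class.toString.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Simplified sketch of type-checking map values with optional censoring of the offending entry.
final class MapValueValidationSketch {

    static List<String> validateMapValues(
        Map<String, Object> map,
        List<Class<?>> allowedTypes,
        String settingName,
        boolean censorValue
    ) {
        List<String> errors = new ArrayList<>();
        if (map == null) {
            return errors;
        }

        // Sorted, comma-separated list of accepted type names for the message.
        String validTypes = String.join(", ", allowedTypes.stream().map(Class::getSimpleName).sorted().toArray(String[]::new));

        for (var entry : map.entrySet()) {
            boolean isAllowed = allowedTypes.stream().anyMatch(type -> type.isInstance(entry.getValue()));
            if (isAllowed) {
                continue;
            }

            if (censorValue) {
                // Censored: neither the key nor the value appears in the message.
                errors.add(
                    String.format("Map field [%s] has an entry that is not valid. Value type is not one of [%s].", settingName, validTypes)
                );
            } else {
                errors.add(
                    String.format(
                        "Map field [%s] has an entry that is not valid, [%s => %s]. Value type is not one of [%s].",
                        settingName,
                        entry.getKey(),
                        entry.getValue(),
                        validTypes
                    )
                );
            }
        }
        return errors;
    }
}
```

Passing true for censorValue suits maps that may hold secrets, since the error text can end up in logs or API responses. In the diff, secret parameters are validated with true and headers with false.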
Original file line number Diff line number Diff line change
@@ -18,10 +18,14 @@

 import java.io.IOException;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;

+import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertMapStringsToSecureString;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues;

 public class CustomSecretSettings implements SecretSettings {
     public static final String NAME = "custom_secret_settings";
@@ -36,20 +40,15 @@ public static CustomSecretSettings fromMap(@Nullable Map<String, Object> map) {
         ValidationException validationException = new ValidationException();

         Map<String, Object> requestSecretParamsMap = extractOptionalMap(map, SECRET_PARAMETERS, NAME, validationException);
+        removeNullValues(requestSecretParamsMap);
+        validateMapValues(requestSecretParamsMap, List.of(String.class), SECRET_PARAMETERS, validationException, true);
+        convertMapStringsToSecureString(requestSecretParamsMap);
+
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
         }

-        if (requestSecretParamsMap == null) {
-            return null;
-        } else {
-            Map<String, Object> secureSecretParameters = new HashMap<>();
-            for (String paramKey : requestSecretParamsMap.keySet()) {
-                Object paramValue = requestSecretParamsMap.get(paramKey);
-                secureSecretParameters.put(paramKey, paramValue);
-            }
-            return new CustomSecretSettings(secureSecretParameters);
-        }
+        return new CustomSecretSettings(Objects.requireNonNullElse(requestSecretParamsMap, new HashMap<>()));
     }

     @Override
@@ -58,15 +57,11 @@ public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
     }

     public CustomSecretSettings(@Nullable Map<String, Object> secretParameters) {
-        this.secretParameters = secretParameters;
+        this.secretParameters = Objects.requireNonNullElse(secretParameters, new HashMap<>());
     }

     public CustomSecretSettings(StreamInput in) throws IOException {
-        if (in.readBoolean()) {
-            secretParameters = in.readGenericMap();
-        } else {
-            secretParameters = null;
-        }
+        secretParameters = in.readGenericMap();
     }

     public Map<String, Object> getSecretParameters() {
@@ -76,7 +71,7 @@ public Map<String, Object> getSecretParameters() {
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        if (secretParameters != null) {
+        if (secretParameters.isEmpty() == false) {
             builder.field(SECRET_PARAMETERS, secretParameters);
         }
         builder.endObject();
@@ -95,12 +90,7 @@ public TransportVersion getMinimalSupportedVersion() {

     @Override
     public void writeTo(StreamOutput out) throws IOException {
-        if (secretParameters != null) {
-            out.writeBoolean(true);
-            out.writeGenericMap(secretParameters);
-        } else {
-            out.writeBoolean(false);
-        }
+        out.writeGenericMap(secretParameters);
     }

     @Override
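The CustomSecretSettings hunks above normalize a null secretParameters map to an empty map in the constructor, which lets the stream constructor and writeTo drop the boolean presence flag and lets toXContent test isEmpty() instead of null. The sketch below shows the same pattern in isolation. It is only an approximation under stated assumptions: SecretParamsSketch is a hypothetical class, values are plain strings rather than SecureString, and java.io data streams with a size-prefixed layout stand in for Elasticsearch's StreamInput and StreamOutput.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical holder that never stores null, so serialization needs no presence flag.
final class SecretParamsSketch {
    private final Map<String, String> secretParameters;

    SecretParamsSketch(Map<String, String> secretParameters) {
        // Normalize null to an empty map once, in the constructor.
        this.secretParameters = Objects.requireNonNullElse(secretParameters, new HashMap<>());
    }

    Map<String, String> secretParameters() {
        return secretParameters;
    }

    byte[] toBytes() {
        var bytes = new ByteArrayOutputStream();
        try (var out = new DataOutputStream(bytes)) {
            // An empty map is written as size 0; there is no separate "is present" boolean.
            out.writeInt(secretParameters.size());
            for (var entry : secretParameters.entrySet()) {
                out.writeUTF(entry.getKey());
                out.writeUTF(entry.getValue());
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static SecretParamsSketch fromBytes(byte[] data) {
        try (var in = new DataInputStream(new ByteArrayInputStream(data))) {
            int size = in.readInt();
            Map<String, String> params = new HashMap<>();
            for (int i = 0; i < size; i++) {
                String key = in.readUTF();
                params.put(key, in.readUTF());
            }
            return new SecretParamsSketch(params);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

One consequence worth keeping in mind with this pattern is that the reader and writer must change together: a reader that still expects a leading boolean would misinterpret bytes written without it.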
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@

 import java.io.IOException;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;

@@ -43,8 +44,9 @@
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValueStrings;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues;

 public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings {
     public static final String NAME = "custom_service_settings";
@@ -84,7 +86,7 @@ private record Fields(
         RateLimitSettings rateLimitSettings
     ) {
         public void validate(ValidationException validationException) {
-            validateMapValueStrings(headers, HEADERS, validationException);
+            validateMapValues(headers, List.of(String.class), HEADERS, validationException, false);

             if (requestBodyMap == null || responseParserMap == null || jsonParserMap == null) {
                 throw validationException;
@@ -113,6 +115,7 @@ private static Fields from(
         String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);

         Map<String, Object> headers = extractOptionalMap(map, HEADERS, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        removeNullValues(headers);

         Map<String, Object> requestBodyMap = extractRequiredMap(map, REQUEST, ModelConfigurations.SERVICE_SETTINGS, validationException);

@@ -165,10 +168,10 @@ private static Fields from(
     private static CustomServiceSettings fromRequestMap(Map<String, Object> map, TaskType taskType) {
         ValidationException validationException = new ValidationException();

-        var serviceSettings = from(map, ConfigurationParseContext.REQUEST, taskType, validationException);
+        var serviceSettingsFields = from(map, ConfigurationParseContext.REQUEST, taskType, validationException);

-        serviceSettings.validate(validationException);
-        return CustomServiceSettings.of(serviceSettings);
+        serviceSettingsFields.validate(validationException);
+        return CustomServiceSettings.of(serviceSettingsFields);
     }

     private static CustomServiceSettings of(Fields fields) {
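In the CustomServiceSettings hunks above, removeNullValues(headers) runs when the fields are extracted, before validate calls validateMapValues on the same map. One plausible reason for that ordering is that Class.isInstance returns false for null, so a null header value would otherwise be reported as a type error. The sketch below shows that interaction with plain Java collections. HeaderSanitizer is a hypothetical class, not part of the change, and it works on a copy of the map whereas the helpers in the diff mutate the map in place.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical helper: drops null header values, then insists that the rest are strings.
final class HeaderSanitizer {

    // String.class.isInstance(null) is false, so any null value makes this return false.
    static boolean allStrings(Map<String, Object> map) {
        return map.values().stream().allMatch(String.class::isInstance);
    }

    static Map<String, Object> sanitize(Map<String, Object> headers) {
        Map<String, Object> copy = new HashMap<>(headers);

        // Step 1: remove entries whose value is null.
        copy.values().removeIf(Objects::isNull);

        // Step 2: only now check the types of the remaining values.
        if (allStrings(copy) == false) {
            throw new IllegalArgumentException("Map field [headers] has a value that is not a String");
        }
        return copy;
    }
}
```

With this ordering, a request that sets a header to null is treated as if the header were absent, while a header with a non-string value such as a number is still rejected.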