[ML] Custom Inference Service #125679

Open
wants to merge 43 commits into base: main

43 commits
4f4c603
add inference custom model
Huaixinww Mar 7, 2025
e53b2e4
add unit test
Huaixinww Mar 7, 2025
c593b1a
spotless apply
Huaixinww Mar 7, 2025
0a851c1
add custom validation
Huaixinww Mar 7, 2025
e240f06
xpack core spotless apply
Huaixinww Mar 7, 2025
3ea3053
update commons-lang3's version
Huaixinww Mar 7, 2025
83daf69
Fix compilation after rebase
davidkyle Mar 25, 2025
3cb0cfb
Add missing licences and fix build checks
davidkyle Mar 25, 2025
a3c862c
Remove some unused code
davidkyle Mar 25, 2025
2b7e6fe
Update docs/changelog/125679.yaml
davidkyle Mar 26, 2025
6cc593d
Fix services it
davidkyle Mar 27, 2025
a4630e3
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 7, 2025
95f23f0
Continuing refactor of service settings
jonathan-buttner Apr 9, 2025
014f95b
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 9, 2025
189edba
Moving classes to reflect new structure
jonathan-buttner Apr 9, 2025
4fe3a1f
Refactoring service settings
jonathan-buttner Apr 9, 2025
4ef37f5
Refactoring the request
jonathan-buttner Apr 10, 2025
6bac18b
Adding files to handle generic error response
jonathan-buttner Apr 11, 2025
f644471
Making progress on tests
jonathan-buttner Apr 15, 2025
11cf7cc
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 15, 2025
f962d74
Adding more tests
jonathan-buttner Apr 16, 2025
eb63e8b
Adding more tests
jonathan-buttner Apr 18, 2025
adc3210
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 18, 2025
c9ff298
Adding tests for remaining parsers
jonathan-buttner Apr 21, 2025
de83271
More tests
jonathan-buttner Apr 22, 2025
34df922
Need to address quoted strings
jonathan-buttner Apr 24, 2025
b496732
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 28, 2025
097246b
Adding query parameter handling and tests
jonathan-buttner Apr 28, 2025
e7f6ac5
Adding encoding tests
jonathan-buttner Apr 29, 2025
a8c5241
Fixing embedding dimensions issue and test field names
jonathan-buttner Apr 29, 2025
3df0f70
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 29, 2025
ad55337
Fixing tests
jonathan-buttner Apr 29, 2025
4714fd3
[CI] Auto commit changes from spotless
elasticsearchmachine Apr 29, 2025
d13191c
Removing licenses
jonathan-buttner Apr 29, 2025
12d46d7
Adding custom service tests
jonathan-buttner May 2, 2025
0134346
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
e6fefc4
[CI] Auto commit changes from spotless
elasticsearchmachine May 2, 2025
eef7188
Correcting transport version number
jonathan-buttner May 2, 2025
83837c8
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 2, 2025
dc02425
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
59f75b9
Cleaning up
jonathan-buttner May 2, 2025
8a82163
Fixing counts
jonathan-buttner May 5, 2025
5f13d28
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 5, 2025
5 changes: 5 additions & 0 deletions docs/changelog/125679.yaml
@@ -0,0 +1,5 @@
+pr: 125679
+summary: Adding support for generic Inference services
+area: Machine Learning
+type: enhancement
+issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -168,6 +168,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_BLOCK_8_19 = def(8_841_0_24);
public static final TransportVersion INTRODUCE_FAILURES_LIFECYCLE_BACKPORT_8_19 = def(8_841_0_25);
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
+public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_X = def(8_841_0_27);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
@@ -242,6 +243,7 @@ static TransportVersion def(int id) {
public static final TransportVersion WRITE_LOAD_INCLUDES_BUFFER_WRITES = def(9_070_00_0);
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION = def(9_071_0_00);
public static final TransportVersion FILE_SETTINGS_HEALTH_INFO = def(9_072_0_00);
+public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_073_0_00);

/*
* STOP! READ THIS FIRST! No, really,

@@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
-assertThat(services.size(), equalTo(22));
+assertThat(services.size(), equalTo(23));

var providers = providers(services);

@@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
+"custom",
"deepseek",
"elastic",
"elasticsearch",
@@ -70,7 +71,7 @@ private Iterable<String> providers(List<Object> services) {

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
-assertThat(services.size(), equalTo(16));
+assertThat(services.size(), equalTo(17));

var providers = providers(services);

@@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
+"custom",
"elasticsearch",
"googleaistudio",
"googlevertexai",
@@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
-assertThat(services.size(), equalTo(7));
+assertThat(services.size(), equalTo(8));

var providers = providers(services);

@@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
List.of(
"alibabacloud-ai-search",
"cohere",
+"custom",
"elasticsearch",
"googlevertexai",
"jinaai",
@@ -123,7 +126,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
-assertThat(services.size(), equalTo(10));
+assertThat(services.size(), equalTo(11));

var providers = providers(services);

@@ -137,6 +140,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
+"custom",
"deepseek",
"googleaistudio",
"openai",
@@ -157,7 +161,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
-assertThat(services.size(), equalTo(6));
+assertThat(services.size(), equalTo(7));

var providers = providers(services);

@@ -166,6 +170,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
+"custom",
"elastic",
"elasticsearch",
"hugging_face",

@@ -59,6 +59,15 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
+import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
+import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
+import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
+import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
@@ -154,6 +163,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
+addCustomNamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

@@ -165,6 +175,38 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return namedWriteables;
}

+private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+namedWriteables.add(
+new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new)
+);
+
+namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new));
+
+namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new));
+
+namedWriteables.add(
+new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new)
+);
+
+namedWriteables.add(
+new NamedWriteableRegistry.Entry(
+CustomResponseParser.class,
+SparseEmbeddingResponseParser.NAME,
+SparseEmbeddingResponseParser::new
+)
+);
+
+namedWriteables.add(
+new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new)
+);
+
+namedWriteables.add(new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new));
+
+namedWriteables.add(
+new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new)
+);
+}
+
private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
var writeables = UnifiedCompletionRequest.getNamedWriteables();
namedWriteables.addAll(writeables);

@@ -119,6 +119,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
+import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
@@ -395,6 +396,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
+context -> new CustomService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}

@@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
-private final ResponseParser parseFunction;
+protected final ResponseParser parseFunction;

Review comment (Contributor):
Making this available so the custom response handler can immediately return on a parse failure instead of retrying.

private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;
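
The review comment above explains that `parseFunction` was widened from `private` to `protected` so that a custom handler can stop on a parse failure rather than retry. The following is a minimal, self-contained sketch of that idea, not the code from this PR. `BaseHandler`, `CustomHandler`, `RetryException`, `strictParse`, and `outcome` are hypothetical stand-ins, the parser is simplified to a `Function<String, String>`, and the base class's "retry on parse failure" default is an assumption made for illustration.

```java
import java.util.function.Function;

public class FailFastParseSketch {

    /** Hypothetical exception carrying a retry hint; a stand-in, not the real Elasticsearch class. */
    static class RetryException extends RuntimeException {
        final boolean shouldRetry;

        RetryException(boolean shouldRetry, Throwable cause) {
            super(cause);
            this.shouldRetry = shouldRetry;
        }
    }

    /** Simplified base handler. The parser is protected so subclasses can call it directly. */
    static class BaseHandler {
        protected final Function<String, String> parseFunction;

        BaseHandler(Function<String, String> parseFunction) {
            this.parseFunction = parseFunction;
        }

        String parseResult(String body) {
            try {
                return parseFunction.apply(body);
            } catch (RuntimeException e) {
                // Assumed default for this sketch: a parse failure is reported as retryable.
                throw new RetryException(true, e);
            }
        }
    }

    /** Hypothetical custom handler: same parser, but a parse failure is final. */
    static class CustomHandler extends BaseHandler {
        CustomHandler(Function<String, String> parseFunction) {
            super(parseFunction);
        }

        @Override
        String parseResult(String body) {
            try {
                // Only possible because parseFunction is protected rather than private.
                return parseFunction.apply(body);
            } catch (RuntimeException e) {
                throw new RetryException(false, e);
            }
        }
    }

    /** Toy parser: accepts bodies that look like a JSON object, rejects everything else. */
    static String strictParse(String body) {
        if (body == null || body.startsWith("{") == false) {
            throw new IllegalArgumentException("unexpected response body");
        }
        return body;
    }

    /** Returns "parsed", "retry", or "fail-fast" depending on how the handler treats the body. */
    static String outcome(BaseHandler handler, String body) {
        try {
            handler.parseResult(body);
            return "parsed";
        } catch (RetryException e) {
            return e.shouldRetry ? "retry" : "fail-fast";
        }
    }

    public static void main(String[] args) {
        Function<String, String> parser = FailFastParseSketch::strictParse;
        System.out.println(outcome(new BaseHandler(parser), "not json"));
        System.out.println(outcome(new CustomHandler(parser), "not json"));
        System.out.println(outcome(new CustomHandler(parser), "{}"));
    }
}
```

With a `private` field the subclass would have to go through the base `parseResult` and inherit its retry hint. With `protected` access it can invoke the parser itself and choose its own failure policy.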
