diff --git a/docs/changelog/126930.yaml b/docs/changelog/126930.yaml
new file mode 100644
index 0000000000000..1507cec38ee02
--- /dev/null
+++ b/docs/changelog/126930.yaml
@@ -0,0 +1,5 @@
+pr: 126930
+summary: Adding missing `onFailure` call for Inference API start model request
+area: Machine Learning
+type: bug
+issues: []
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java
index f743b94df3810..2599fef169748 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java
@@ -107,7 +107,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL
             })
             .andThen((l2, modelDidPut) -> {
                 var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
-                var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
+                var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
                 client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
             })
             .addListener(finalListener);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java
index aa12bf0c645c3..f1011efd3b12c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java
@@ -105,6 +105,8 @@ public void onFailure(Exception e) {
                     && statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
                     // Deployment is already started
                     listener.onResponse(Boolean.TRUE);
+                } else {
+                    listener.onFailure(e);
                 }
                 return;
             }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
index d8886e1eea471..7fb9ce8a6c37f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
@@ -36,6 +36,7 @@
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.ParseField;
@@ -47,13 +48,16 @@
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.ml.MachineLearningField;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@@ -1870,6 +1874,49 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException {
         }
     }
 
+    public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
+        var model = new ElserInternalModel(
+            "inference_id",
+            TaskType.SPARSE_EMBEDDING,
+            "elasticsearch",
+            new ElserInternalServiceSettings(
+                new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
+            ),
+            new ElserMlNodeTaskSettings(),
+            null
+        );
+
+        var client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadPool);
+
+        doAnswer(invocationOnMock -> {
+            ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
+            var builder = GetTrainedModelsAction.Response.builder();
+            builder.setModels(List.of(mock(TrainedModelConfig.class)));
+            builder.setTotalCount(1);
+
+            listener.onResponse(builder.build());
+            return Void.TYPE;
+        }).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());
+
+        doAnswer(invocationOnMock -> {
+            ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
+            listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
+            return Void.TYPE;
+        }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());
+
+        try (var service = createService(client)) {
+            var actionListener = new PlainActionFuture<Boolean>();
+            service.start(model, TimeValue.timeValueSeconds(30), actionListener);
+            var exception = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
+            );
+
+            assertThat(exception.getMessage(), is("failed"));
+        }
+    }
+
     private ElasticsearchInternalService createService(Client client) {
         var cs = mock(ClusterService.class);
         var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));