Issue-105420: Fix bug causing incorrect error on force deleting already deleted model #107188

Open
wants to merge 4 commits into
base: main
@@ -8,6 +8,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
@@ -37,9 +38,12 @@
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.InferenceProcessorInfoExtractor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@@ -52,6 +56,9 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -146,11 +153,20 @@ static List<String> getModelAliases(ClusterState clusterState, String modelId) {
return modelAliases;
}

private void deleteModel(DeleteTrainedModelAction.Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener) {
protected void deleteModel(
DeleteTrainedModelAction.Request request,
ClusterState state,
ActionListener<AcknowledgedResponse> listener
) {
String id = request.getId();
IngestMetadata currentIngestMetadata = state.metadata().getProject().custom(IngestMetadata.TYPE);
Set<String> referencedModels = InferenceProcessorInfoExtractor.getModelIdsFromInferenceProcessors(currentIngestMetadata);

if (modelExists(request.getId()) == false) {
Contributor:

nice! this looks very clean and readable.

listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, request.getId())));
return;
}

if (request.isForce() == false && referencedModels.contains(id)) {
listener.onFailure(
new ElasticsearchStatusException(
@@ -199,6 +215,39 @@ private void deleteModel(DeleteTrainedModelAction.Request request, ClusterState
}
}

protected boolean modelExists(String modelId) {
Member:

You can avoid the countdown latch and hence blocking the calling thread by using a listener.

You don't have to time out the call to trainedModelProvider.getTrainedModel(). If it does time out, simply let the error propagate from the call.

    private void modelExists(String modelId, ActionListener<Boolean> listener) {
        trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), null,
            ActionListener.wrap(
                model -> listener.onResponse(Boolean.TRUE),
                exception -> {
                    if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceNotFoundException) {
                        listener.onResponse(Boolean.FALSE);
                    } else {
                        listener.onFailure(exception);
                    }
                }
            )
        );
    }

Member:

After you make the modelExists() function async with a listener, you will need to chain the various processing steps together. The best way to do this is to use a SubscribableListener.

Here's an example of it being used:
https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java#L128
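
To illustrate the chaining idea without depending on Elasticsearch classes, here is a minimal, self-contained sketch. It swaps in the JDK's CompletableFuture for SubscribableListener, so it shows the shape of the flow (async existence check, then the delete step only if the model exists) and not the actual Elasticsearch API. The class `AsyncDeleteSketch`, its `NotFoundException`, and the in-memory `storedModels` set are all hypothetical stand-ins, not names from this PR.

```java
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

// Hypothetical stand-in for the transport action; none of these names come from the PR.
class AsyncDeleteSketch {

    // Stand-in for ResourceNotFoundException.
    static class NotFoundException extends RuntimeException {
        NotFoundException(String message) {
            super(message);
        }
    }

    // Assumption: an in-memory set plays the role of the trained model store.
    private final Set<String> storedModels;

    AsyncDeleteSketch(Set<String> storedModels) {
        this.storedModels = storedModels;
    }

    // Step 1: asynchronous existence check. The result is delivered through the
    // future, so the calling thread never blocks on a latch.
    CompletableFuture<Boolean> modelExists(String modelId) {
        return CompletableFuture.supplyAsync(() -> storedModels.contains(modelId));
    }

    // Step 2: chained after step 1. A missing model fails the chain with a
    // not-found error; any other failure from step 1 propagates unchanged.
    CompletableFuture<String> deleteModel(String modelId) {
        return modelExists(modelId).thenCompose(exists -> {
            if (exists == false) {
                return CompletableFuture.<String>failedFuture(
                    new NotFoundException("Could not find trained model [" + modelId + "]")
                );
            }
            return CompletableFuture.supplyAsync(() -> "deleted " + modelId);
        });
    }

    public static void main(String[] args) {
        AsyncDeleteSketch action = new AsyncDeleteSketch(Set.of("model-a"));
        System.out.println(action.deleteModel("model-a").join());
        try {
            action.deleteModel("model-b").join();
        } catch (CompletionException e) {
            // join() wraps the original failure; the cause is the NotFoundException.
            System.out.println(e.getCause().getMessage());
        }
    }
}
```

In the real action, the same structure would presumably be expressed with listeners, ending by notifying the request's ActionListener instead of calling join().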

CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean modelExists = new AtomicBoolean(false);

Comment on lines +218 to +221
Contributor:

To avoid using a latch and requiring a timeout, I think we could replace this function with an ActionListener. What do you think?

ActionListener<TrainedModelConfig> trainedModelListener = new ActionListener<>() {
@Override
public void onResponse(TrainedModelConfig config) {
modelExists.set(true);
latch.countDown();
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to retrieve model {}: {}", modelId, e.getMessage(), e);
Member:

This isn't an error as we are checking the model's existence here. If the model doesn't exist then that is fine and it should be reported back to the caller rather than logged.

Contributor:

Suggested change
logger.error("Failed to retrieve model {}: {}", modelId, e.getMessage(), e);
logger.error("Failed to retrieve model [" + modelId + "]: [" + e.getMessage() + "]", e);

latch.countDown();
}
};

trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), null, trainedModelListener);
Contributor:

If we don't pass a parent task to this request, it won't be cancelable. I think it would be better to pass in the task from the masterOperation here.


try {
boolean latchReached = latch.await(5, TimeUnit.SECONDS);
Contributor:

I don't think we want a 5-second timeout here that throws an exception. If this timeout occurs, the exception won't make it any clearer to the end user what happened.


if (latchReached == false) {
throw new ElasticsearchException("Timeout while waiting for trained model to be retrieved");
}
} catch (InterruptedException e) {
throw new ElasticsearchException("Unexpected exception", e);
Member:

This code is not necessary if you take my suggestion, but in Java it's best practice to reset the interrupt flag with Thread.currentThread().interrupt() when catching InterruptedException.
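
For reference, a minimal sketch of that interrupt-handling pattern using only JDK classes. The `InterruptSketch` class and `awaitQuietly` helper are hypothetical names for illustration. Unlike the PR code, which throws an ElasticsearchException, this sketch simply returns false on interruption to keep the example small.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

// Hypothetical helper class; not part of the PR.
class InterruptSketch {

    // Waits for the latch. Returns false if the wait timed out or was interrupted.
    static boolean awaitQuietly(CountDownLatch latch, long timeout, TimeUnit unit) {
        try {
            return latch.await(timeout, unit);
        } catch (InterruptedException e) {
            // await() clears the thread's interrupt status when it throws.
            // Setting it again lets code further up the stack see the interruption.
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        // Simulate an interruption arriving before the wait starts.
        Thread.currentThread().interrupt();
        boolean completed = awaitQuietly(new CountDownLatch(1), 1, TimeUnit.SECONDS);
        // Thread.interrupted() reports the flag and clears it.
        System.out.println("completed=" + completed + ", interrupted=" + Thread.interrupted());
    }
}
```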

}

return modelExists.get();
}

private void forceStopDeployment(String modelId, ActionListener<StopTrainedModelDeploymentAction.Response> listener) {
StopTrainedModelDeploymentAction.Request request = new StopTrainedModelDeploymentAction.Request(modelId);
request.setForce(true);
@@ -14,15 +14,35 @@
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.license.License;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.util.Collections;
import java.util.concurrent.TimeUnit;
@@ -31,12 +51,16 @@
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.getTaskInfoListOfOne;
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockClientWithTasksResponse;
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockListTasksClient;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class TransportDeleteTrainedModelActionTests extends ESTestCase {
@@ -136,6 +160,76 @@ public void testCancelDownloadTaskCallsOnResponseWithTheCancelResponseWhenATaskE
assertThat(listener.actionGet(TIMEOUT), is(cancelResponse));
}

public void testModelExistsIsTrueWhenModelIsFound() {
TrainedModelProvider trainedModelProvider = mock(TrainedModelProvider.class);
TrainedModelConfig expectedConfig = buildTrainedModelConfig("modelId");

Mockito.doAnswer(invocation -> {
ActionListener<TrainedModelConfig> listener = invocation.getArgument(3);
listener.onResponse(expectedConfig);
return null;
})
.when(trainedModelProvider)
.getTrainedModel(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.<ActionListener<TrainedModelConfig>>any());

TransportDeleteTrainedModelAction transportDeleteTrainedModelAction = createTransportDeleteTrainedModelAction(trainedModelProvider);
boolean modelExists = transportDeleteTrainedModelAction.modelExists("modelId");

assertThat(modelExists, is(Boolean.TRUE));
}

public void testModelExistsIsFalseWhenModelIsNotFound() {
TrainedModelProvider trainedModelProvider = mock(TrainedModelProvider.class);
Exception failureException = new Exception("Failed to retrieve model");

Mockito.doAnswer(invocation -> {
ActionListener<TrainedModelConfig> listener = invocation.getArgument(3);
listener.onFailure(failureException);
return null;
})
.when(trainedModelProvider)
.getTrainedModel(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.<ActionListener<TrainedModelConfig>>any());

TransportDeleteTrainedModelAction transportDeleteTrainedModelAction = createTransportDeleteTrainedModelAction(trainedModelProvider);
boolean modelExists = transportDeleteTrainedModelAction.modelExists("modelId");

assertThat(modelExists, is(Boolean.FALSE));
}

public void testDeleteModelThrowsExceptionWhenModelIsNotFound() {
TrainedModelProvider trainedModelProvider = mock(TrainedModelProvider.class);
Exception failureException = new Exception("Failed to retrieve model");
ClusterState CLUSTER_STATE = ClusterState.builder(new ClusterName("test"))
.metadata(
Metadata.builder()
.put(
IndexMetadata.builder(".my-system").system(true).settings(indexSettings(IndexVersion.current(), 1, 0)).build(),
true
)
.build()
)
.build();
@SuppressWarnings("unchecked")
ActionListener<AcknowledgedResponse> mockedListener = mock(ActionListener.class);

Mockito.doAnswer(invocation -> {
ActionListener<TrainedModelConfig> listener = invocation.getArgument(3);
listener.onFailure(failureException);
return null;
})
.when(trainedModelProvider)
.getTrainedModel(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.<ActionListener<TrainedModelConfig>>any());

TransportDeleteTrainedModelAction transportDeleteTrainedModelAction = createTransportDeleteTrainedModelAction(trainedModelProvider);
transportDeleteTrainedModelAction.deleteModel(new DeleteTrainedModelAction.Request("modelId"), CLUSTER_STATE, mockedListener);

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(mockedListener).onFailure(exceptionCaptor.capture());
Exception capturedException = exceptionCaptor.getValue();
assertThat(capturedException, is(instanceOf(ResourceNotFoundException.class)));
assertThat(capturedException.getMessage(), containsString("Could not find trained model [modelId]"));
}

private static void mockCancelTask(Client client) {
var cluster = client.admin().cluster();
when(cluster.prepareCancelTasks()).thenReturn(new CancelTasksRequestBuilder(client));
@@ -152,4 +246,43 @@ private static void mockCancelTasksResponse(Client client, ListTasksResponse res
return Void.TYPE;
}).when(client).execute(same(TransportCancelTasksAction.TYPE), any(), any());
}

private TransportDeleteTrainedModelAction createTransportDeleteTrainedModelAction(TrainedModelProvider configProvider) {
TransportService mockTransportService = mock(TransportService.class);
doReturn(threadPool).when(mockTransportService).getThreadPool();
ClusterService mockClusterService = mock(ClusterService.class);
ActionFilters mockFilters = mock(ActionFilters.class);
doReturn(null).when(mockFilters).filters();
Client mockClient = mock(Client.class);
doReturn(null).when(mockClient).settings();
doReturn(threadPool).when(mockClient).threadPool();
InferenceAuditor auditor = mock(InferenceAuditor.class);

return new TransportDeleteTrainedModelAction(
mockTransportService,
mockClusterService,
threadPool,
null,
mockFilters,
null,
configProvider,
auditor,
null
);
}

private static TrainedModelConfig buildTrainedModelConfig(String modelId) {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setDescription("trained model config for test")
.setModelId(modelId)
.setModelType(TrainedModelType.TREE_ENSEMBLE)
.setVersion(MlConfigVersion.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setModelSize(0)
.setEstimatedOperations(0)
.setInput(TrainedModelInputTests.createRandomInput())
.build();
}
}