[Inference API] Align Get/Update Inference APIs with index API pattern #124179

Open · wants to merge 2 commits into main
docs/changelog/124179.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+pr: 124179
+summary: Align Get/Update Inference APIs with index API pattern
+area: Inference
+type: enhancement
+issues: []
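
Context for the change: the Get and Update inference APIs now accept the same target forms as the index APIs, i.e. a concrete inference ID, a comma-separated list, wildcard patterns, or `_all`/`*`. The sketch below is illustrative only (not code from this PR) and uses a simplified stand-in for Elasticsearch's `Regex.simpleMatch` helper; it shows how such an expression can be resolved against a set of known endpoint IDs.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative resolver for index-API-style target expressions such as
// "my-elser,text-*". Hypothetical demo code, not part of this PR.
public class TargetExpressionDemo {

    // Simplified stand-in for org.elasticsearch.common.regex.Regex.simpleMatch:
    // '*' matches any (possibly empty) substring, everything else is literal.
    static boolean simpleMatch(String pattern, String value) {
        int star = pattern.indexOf('*');
        if (star == -1) {
            return pattern.equals(value);
        }
        if (value.startsWith(pattern.substring(0, star)) == false) {
            return false;
        }
        String restPattern = pattern.substring(star + 1);
        for (int i = star; i <= value.length(); i++) {
            if (simpleMatch(restPattern, value.substring(i))) {
                return true;
            }
        }
        return false;
    }

    static List<String> resolve(String expression, List<String> knownIds) {
        if (expression.equals("_all") || expression.equals("*")) {
            return new ArrayList<>(knownIds); // "match everything" short-circuit
        }
        List<String> matched = new ArrayList<>();
        for (String pattern : expression.split(",")) {
            for (String id : knownIds) {
                if (simpleMatch(pattern, id) && matched.contains(id) == false) {
                    matched.add(id);
                }
            }
        }
        return matched;
    }

    public static void main(String[] args) {
        List<String> ids = List.of("my-elser", "text-embedding-small", "text-embedding-large");
        // Prints: [my-elser, text-embedding-small, text-embedding-large]
        System.out.println(resolve("my-elser,text-*", ids));
    }
}
```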
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -185,6 +185,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RE_REMOVE_MIN_COMPATIBLE_SHARD_NODE = def(9_021_0_00);
public static final TransportVersion UNASSIGENEDINFO_RESHARD_ADDED = def(9_022_0_00);
public static final TransportVersion INCLUDE_INDEX_MODE_IN_GET_DATA_STREAM = def(9_023_0_00);
+    public static final TransportVersion INFERENCE_API_WILDCARD_ALLOWED = def(9_024_0_00);

/*
* STOP! READ THIS FIRST! No, really,
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UpdateInferenceModelAction.java

@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
@@ -28,8 +29,10 @@
import org.elasticsearch.xpack.core.ml.utils.MlStrings;

import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Objects;

@@ -204,20 +207,6 @@ public void writeTo(StreamOutput out) throws IOException {
XContentHelper.writeTo(out, contentType);
}

-    @Override
-    public ActionRequestValidationException validate() {
-        ActionRequestValidationException validationException = new ActionRequestValidationException();
-        if (MlStrings.isValidId(this.inferenceEntityId) == false) {
-            validationException.addValidationError(Messages.getMessage(Messages.INVALID_ID, "inference_id", this.inferenceEntityId));
-        }
-
-        if (validationException.validationErrors().isEmpty() == false) {
-            return validationException;
-        } else {
-            return null;
-        }
-    }
-
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -237,42 +226,65 @@ public int hashCode() {

    public static class Response extends ActionResponse implements ToXContentObject {

-        private final ModelConfigurations model;
+        private final List<ModelConfigurations> endpoints;
+
+        public Response(List<ModelConfigurations> endpoints) {
+            this.endpoints = endpoints;
+        }

        public Response(ModelConfigurations model) {
-            this.model = model;
+            endpoints = new ArrayList<>();
+            endpoints.add(model);
        }

        public Response(StreamInput in) throws IOException {
            super(in);
-            model = new ModelConfigurations(in);
+            if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_WILDCARD_ALLOWED)) {
+                endpoints = in.readCollectionAsList(ModelConfigurations::new);
+            } else {
+                endpoints = new ArrayList<>();
+                endpoints.add(new ModelConfigurations(in));
+            }
        }

        public ModelConfigurations getModel() {
-            return model;
+            return endpoints.get(0);
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
-            model.writeTo(out);
+            if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_WILDCARD_ALLOWED)) {
+                out.writeCollection(endpoints);
+            } else {
+                endpoints.get(0).writeTo(out);
+            }
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            return model.toFilteredXContent(builder, params);
+            builder.startObject();
+            builder.startArray("endpoints");
+            for (var endpoint : endpoints) {
+                if (endpoint != null) {
+                    endpoint.toFilteredXContent(builder, params);
+                }
+            }
+            builder.endArray();
+            builder.endObject();
+            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Response response = (Response) o;
-            return Objects.equals(model, response.model);
+            return Objects.equals(endpoints, response.endpoints);
        }

        @Override
        public int hashCode() {
-            return Objects.hash(model);
+            return Objects.hash(endpoints);
        }
    }
}
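
The serialization above is gated on the new transport version so that a node with this change can still exchange responses with an older node that expects a single `ModelConfigurations`. A minimal, self-contained sketch of that wire-compatibility pattern (plain `java.io` streams and simplified types standing in for `StreamOutput`/`StreamInput`; illustrative, not the PR's code):

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo of version-gated wire formats: new nodes exchange a list,
// old nodes only understand the single-item format.
class VersionGatedWireFormat {
    static final int INFERENCE_API_WILDCARD_ALLOWED = 9_024_0_00;

    static void write(DataOutputStream out, int remoteVersion, List<String> endpoints) throws IOException {
        if (remoteVersion >= INFERENCE_API_WILDCARD_ALLOWED) {
            out.writeInt(endpoints.size());   // new format: explicit element count
            for (String endpoint : endpoints) {
                out.writeUTF(endpoint);
            }
        } else {
            out.writeUTF(endpoints.get(0));   // old format: exactly one element
        }
    }

    static List<String> read(DataInputStream in, int remoteVersion) throws IOException {
        List<String> endpoints = new ArrayList<>();
        if (remoteVersion >= INFERENCE_API_WILDCARD_ALLOWED) {
            int count = in.readInt();
            for (int i = 0; i < count; i++) {
                endpoints.add(in.readUTF());
            }
        } else {
            endpoints.add(in.readUTF());      // older peers never see more than one
        }
        return endpoints;
    }
}
```

Note the asymmetry this implies on the real code path: when talking to an older node, only `endpoints.get(0)` is written, so multi-endpoint (wildcard) responses are effectively truncated to a single entry until the whole cluster is on the new version.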
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java

@@ -76,7 +76,11 @@ protected void doExecute(
} else if (inferenceEntityIdIsWildCard) {
getModelsByTaskType(request.getTaskType(), listener);
} else {
-            getSingleModel(request.getInferenceEntityId(), request.getTaskType(), listener);
+            if (request.getInferenceEntityId().contains(",") || request.getInferenceEntityId().contains("*")) {
+                getModelsByTaskTypeAndIds(request.getTaskType(), request.getInferenceEntityId(), listener);
+            } else {
+                getSingleModel(request.getInferenceEntityId(), request.getTaskType(), listener);
+            }
}
}

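The branch added above gives `doExecute` three routes for the requested inference ID. A compact sketch of that decision, under the assumption (matching the diff) that the upstream wildcard check mirrors `Strings.isAllOrWildcard`:

```java
// Illustrative routing table for GET _inference targets; demo code only.
class InferenceIdRouting {
    enum Route { ALL_FOR_TASK_TYPE, BY_ID_EXPRESSION, SINGLE_ID }

    static Route route(String inferenceEntityId) {
        if (inferenceEntityId.equals("_all") || inferenceEntityId.equals("*")) {
            return Route.ALL_FOR_TASK_TYPE;   // list every endpoint of the task type
        }
        if (inferenceEntityId.contains(",") || inferenceEntityId.contains("*")) {
            return Route.BY_ID_EXPRESSION;    // resolve the expression in the model registry
        }
        return Route.SINGLE_ID;               // fetch one concrete endpoint
    }

    public static void main(String[] args) {
        System.out.println(route("_all"));            // ALL_FOR_TASK_TYPE
        System.out.println(route("my-elser,text-*")); // BY_ID_EXPRESSION
        System.out.println(route("my-elser"));        // SINGLE_ID
    }
}
```
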
@@ -128,6 +132,18 @@ private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {
);
}

+    private void getModelsByTaskTypeAndIds(
+        TaskType taskType,
+        String inferenceEntityIdExpression,
+        ActionListener<GetInferenceModelAction.Response> listener
+    ) {
+        modelRegistry.getModelsByTaskTypeAndInferenceEntityExpression(
+            taskType,
+            inferenceEntityIdExpression,
+            listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener)))
+        );
+    }
+
private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
if (unparsedModels.isEmpty()) {
listener.onResponse(new GetInferenceModelAction.Response(List.of()));
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java

@@ -13,6 +13,7 @@
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.internal.Client;
@@ -49,12 +50,15 @@
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.inference.common.InferenceExceptions;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;

import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -118,80 +122,91 @@ protected void masterOperation(
        var bodyTaskType = request.getContentAsSettings().taskType();
        var resolvedTaskType = resolveTaskType(request.getTaskType(), bodyTaskType != null ? bodyTaskType.toString() : null);

-        AtomicReference<InferenceService> service = new AtomicReference<>();
-
-        var inferenceEntityId = request.getInferenceEntityId();
-
-        SubscribableListener.<UnparsedModel>newForked(listener -> { checkEndpointExists(inferenceEntityId, listener); })
-            .<UnparsedModel>andThen((listener, unparsedModel) -> {
-
-                Optional<InferenceService> optionalService = serviceRegistry.getService(unparsedModel.service());
-                if (optionalService.isEmpty()) {
-                    listener.onFailure(
-                        new ElasticsearchStatusException(
-                            "Service [{}] not found",
-                            RestStatus.INTERNAL_SERVER_ERROR,
-                            unparsedModel.service()
-                        )
-                    );
-                } else {
-                    service.set(optionalService.get());
-                    listener.onResponse(unparsedModel);
-                }
-            })
-            .<Boolean>andThen((listener, existingUnparsedModel) -> {
-
-                Model existingParsedModel = service.get()
-                    .parsePersistedConfigWithSecrets(
-                        request.getInferenceEntityId(),
-                        existingUnparsedModel.taskType(),
-                        new HashMap<>(existingUnparsedModel.settings()),
-                        new HashMap<>(existingUnparsedModel.secrets())
-                    );
-
-                Model newModel = combineExistingModelWithNewSettings(
-                    existingParsedModel,
-                    request.getContentAsSettings(),
-                    service.get().name(),
-                    resolvedTaskType
-                );
-
-                if (isInClusterService(service.get().name())) {
-                    updateInClusterEndpoint(request, newModel, existingParsedModel, listener);
-                } else {
-                    modelRegistry.updateModelTransaction(newModel, existingParsedModel, listener);
-                }
-            })
-            .<ModelConfigurations>andThen((listener, didUpdate) -> {
-                if (didUpdate) {
-                    modelRegistry.getModel(inferenceEntityId, ActionListener.wrap((unparsedModel) -> {
-                        if (unparsedModel == null) {
-                            listener.onFailure(
-                                new ElasticsearchStatusException(
-                                    "Failed to update model, updated model not found",
-                                    RestStatus.INTERNAL_SERVER_ERROR
-                                )
-                            );
-                        } else {
-                            listener.onResponse(
-                                service.get()
-                                    .parsePersistedConfig(
-                                        request.getInferenceEntityId(),
-                                        resolvedTaskType,
-                                        new HashMap<>(unparsedModel.settings())
-                                    )
-                                    .getConfigurations()
-                            );
-                        }
-                    }, listener::onFailure));
-                } else {
-                    listener.onFailure(new ElasticsearchStatusException("Failed to update model", RestStatus.INTERNAL_SERVER_ERROR));
-                }
-
-            }).<UpdateInferenceModelAction.Response>andThen((listener, modelConfig) -> {
-                listener.onResponse(new UpdateInferenceModelAction.Response(modelConfig));
-            })
-            .addListener(masterListener);
+        SubscribableListener.<List<UnparsedModel>>newForked(listener -> {
+            checkEndpointsExists(request.getInferenceEntityId(), resolvedTaskType, listener);
+        }).<List<ModelConfigurations>>andThen((upperListener, unparsedModels) -> {
+            List<ModelConfigurations> existingUnparsedModels = Collections.synchronizedList(new ArrayList<>());
+
+            try (var listeners = new RefCountingListener(upperListener.map(v -> existingUnparsedModels))) {
+                for (UnparsedModel model : unparsedModels) {
+                    AtomicReference<InferenceService> service = new AtomicReference<>();
+                    String existingInferenceEntityId = model.inferenceEntityId();
+
+                    SubscribableListener.<UnparsedModel>newForked(listener -> { checkEndpointExists(model.inferenceEntityId(), listener); })
+                        .<UnparsedModel>andThen((listener, unparsedModel) -> {
+
+                            Optional<InferenceService> optionalService = serviceRegistry.getService(unparsedModel.service());
+                            if (optionalService.isEmpty()) {
+                                listener.onFailure(
+                                    new ElasticsearchStatusException(
+                                        "Service [{}] not found",
+                                        RestStatus.INTERNAL_SERVER_ERROR,
+                                        unparsedModel.service()
+                                    )
+                                );
+                            } else {
+                                service.set(optionalService.get());
+                                listener.onResponse(unparsedModel);
+                            }
+                        })
+                        .<Boolean>andThen((listener, existingUnparsedModel) -> {
+
+                            Model existingParsedModel = service.get()
+                                .parsePersistedConfigWithSecrets(
+                                    existingInferenceEntityId,
+                                    existingUnparsedModel.taskType(),
+                                    new HashMap<>(existingUnparsedModel.settings()),
+                                    new HashMap<>(existingUnparsedModel.secrets())
+                                );
+
+                            Model newModel = combineExistingModelWithNewSettings(
+                                existingParsedModel,
+                                request.getContentAsSettings(),
+                                service.get().name(),
+                                resolvedTaskType
+                            );
+
+                            if (isInClusterService(service.get().name())) {
+                                updateInClusterEndpoint(existingInferenceEntityId, request, newModel, existingParsedModel, listener);
+                            } else {
+                                modelRegistry.updateModelTransaction(newModel, existingParsedModel, listener);
+                            }
+                        })
+                        .<ModelConfigurations>andThen((listener, didUpdate) -> {
+                            if (didUpdate) {
+                                modelRegistry.getModel(existingInferenceEntityId, ActionListener.wrap((unparsedModel) -> {
+                                    if (unparsedModel == null) {
+                                        listener.onFailure(
+                                            new ElasticsearchStatusException(
+                                                "Failed to update model, updated model not found",
+                                                RestStatus.INTERNAL_SERVER_ERROR
+                                            )
+                                        );
+                                    } else {
+                                        listener.onResponse(
+                                            service.get()
+                                                .parsePersistedConfig(
+                                                    existingInferenceEntityId,
+                                                    resolvedTaskType,
+                                                    new HashMap<>(unparsedModel.settings())
+                                                )
+                                                .getConfigurations()
+                                        );
+                                    }
+                                }, listener::onFailure));
+                            } else {
+                                listener.onFailure(
+                                    new ElasticsearchStatusException("Failed to update model", RestStatus.INTERNAL_SERVER_ERROR)
+                                );
+                            }
+
+                        })
+                        .addListener(listeners.acquire(existingUnparsedModels::add));
+                }
+            }
+        }).<UpdateInferenceModelAction.Response>andThen((listener, existingUnparsedModels) -> {
+            listener.onResponse(new UpdateInferenceModelAction.Response(existingUnparsedModels));
+        }).addListener(masterListener);
}

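The rewritten `masterOperation` above fans out one listener chain per matched endpoint and completes the top-level listener only once every chain has finished, which is what `RefCountingListener` with `listeners.acquire(existingUnparsedModels::add)` provides. A minimal sketch of the same fan-out/fan-in shape using plain JDK concurrency primitives (illustrative stand-ins, not the Elasticsearch API):

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Demo of the fan-out pattern: one task per endpoint, results collected in a
// thread-safe list, completion signalled only after all tasks have finished.
class FanOutDemo {
    static List<String> updateAll(List<String> endpointIds) throws InterruptedException {
        List<String> results = new CopyOnWriteArrayList<>();   // plays the role of existingUnparsedModels
        CountDownLatch allDone = new CountDownLatch(endpointIds.size());
        ExecutorService executor = Executors.newFixedThreadPool(4);
        try {
            for (String id : endpointIds) {
                executor.submit(() -> {
                    try {
                        results.add(id + ": updated");         // per-endpoint listener chain
                    } finally {
                        allDone.countDown();                   // releasing an acquired ref
                    }
                });
            }
            allDone.await();                                   // RefCountingListener completing
        } finally {
            executor.shutdown();
        }
        return results;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(updateAll(List.of("my-elser", "text-embedding-small")));
    }
}
```
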
/**
@@ -249,6 +264,7 @@ private Model combineExistingModelWithNewSettings(
}

private void updateInClusterEndpoint(
+        String inferenceEntityId,
UpdateInferenceModelAction.Request request,
Model newModel,
Model existingParsedModel,
@@ -271,7 +287,7 @@ private void updateInClusterEndpoint(
logger.info(
"Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations",
deploymentId,
-                request.getInferenceEntityId(),
+                inferenceEntityId,
numAllocations
);
client.execute(UpdateTrainedModelDeploymentAction.INSTANCE, updateRequest, delegate);
@@ -339,6 +355,22 @@ private void checkEndpointExists(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
}));
}

+    private void checkEndpointsExists(String inferenceEntityId, TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
+        if (Strings.isAllOrWildcard(inferenceEntityId)) {
+            modelRegistry.getModelsByTaskType(taskType, listener);
+        } else if (inferenceEntityId.contains(",") || inferenceEntityId.contains("*")) {
+            modelRegistry.getModelsByTaskTypeAndInferenceEntityExpression(taskType, inferenceEntityId, listener);
+        } else {
+            modelRegistry.getModel(inferenceEntityId, listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
+                if (taskType.isAnyOrSame(unparsedModel.taskType()) == false) {
+                    delegate.onFailure(InferenceExceptions.mismatchedTaskTypeException(taskType, unparsedModel.taskType()));
+                    return;
+                }
+                listener.onResponse(List.of(unparsedModel));
+            }));
+        }
+    }
+
private static XContentParser getParser(UpdateInferenceModelAction.Request request) throws IOException {
return XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType());
}