Skip to content

[Inference API] Add "rerank" task type to "elastic" provider #126022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

timgrein
Copy link
Contributor

@timgrein timgrein commented Apr 1, 2025

Implements the rerank task type for the elastic provider.

@elasticsearchmachine elasticsearchmachine added v9.1.0 needs:triage Requires assignment of a team area label labels Apr 1, 2025
@kingherc kingherc added the :SearchOrg/Inference Label for the Search Inference team label Apr 1, 2025
@elasticsearchmachine elasticsearchmachine added Team:SearchOrg Meta label for the Search Org (Enterprise Search) Team:Search - Inference and removed needs:triage Requires assignment of a team area label labels Apr 1, 2025
@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/search-inference-team (Team:Search - Inference)

@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/search-eng (Team:SearchOrg)

@timgrein timgrein added :ml Machine learning Team:ML Meta label for the ML team labels May 6, 2025
@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/ml-core (Team:ML)

Copy link
Contributor

@jonathan-buttner jonathan-buttner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good, left some suggestions

this.query = query;
this.documents = documents;
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We probably don't need a reference to the uri since we have a reference to the model.

@@ -0,0 +1,79 @@
/*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to transition away from the request manager pattern to avoid the extra class since all the classes are pretty similar.

Here's an example of how we implemented it for voyageai: #124512

And the rerank usage: https://github.com/elastic/elasticsearch/pull/124512/files#diff-3493deea8c9fd5276917f1f8a9d7f008268c34b61f230885df0643de10f19fffR71

public ExecutableAction create(ElasticInferenceServiceRerankModel model) {
var requestManager = new ElasticInferenceServiceRerankRequestManager(model, serviceComponents, traceContext);
var errorMessage = constructFailedToSendRequestMessage(
String.format(Locale.ROOT, "%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think Strings.format() avoids the need for Locale.ROOT.

@@ -214,6 +214,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK = def(9_048_0_00);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the PR is only targeting 9.1 did we also want to support 8.19? If so we'll need to add another transport version and do the backport dance.

return EMPTY_SETTINGS;
}

Integer topNDocumentsOnly = extractOptionalPositiveInteger(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can omit this class. We've moved the common rerank parameters up to the root level of the request and they're passed in to the infer() call from the InferenceAction class. So I think we'll want the ElasticInferenceService to Override the validateRerankParameters from SenderService to ensure only top n is set.

Here's where that's being called by the SenderService: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java#L158

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
:ml Machine learning >non-issue :SearchOrg/Inference Label for the Search Inference team Team:ML Meta label for the ML team Team:Search - Inference Team:SearchOrg Meta label for the Search Org (Enterprise Search) v9.1.0
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants