Skip to content

Commit 16538d6

Browse files
authored
[ML] Directly call Inference API from Proxy (#127342) (#127396)
In order to propagate response headers back from the proxied actions, we are directly calling the Transport actions via the NodeClient.
1 parent c00a8fc commit 16538d6

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java

+20-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,15 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.ActionRequest;
13+
import org.elasticsearch.action.ActionResponse;
14+
import org.elasticsearch.action.ActionType;
1215
import org.elasticsearch.action.support.ActionFilters;
16+
import org.elasticsearch.action.support.ContextPreservingActionListener;
1317
import org.elasticsearch.action.support.HandledTransportAction;
1418
import org.elasticsearch.client.internal.Client;
1519
import org.elasticsearch.common.util.concurrent.EsExecutors;
20+
import org.elasticsearch.common.util.concurrent.ThreadContext;
1621
import org.elasticsearch.common.xcontent.XContentHelper;
1722
import org.elasticsearch.inference.TaskType;
1823
import org.elasticsearch.inference.UnparsedModel;
@@ -30,7 +35,6 @@
3035
import java.io.IOException;
3136

3237
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
33-
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
3438

3539
public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
3640
private final ModelRegistry modelRegistry;
@@ -103,7 +107,7 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request,
103107
);
104108
}
105109

106-
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener);
110+
execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, listener);
107111
} catch (Exception e) {
108112
unifiedErrorFormatListener.onFailure(e);
109113
}
@@ -122,6 +126,19 @@ private void sendInferenceActionRequest(InferenceActionProxy.Request request, Ac
122126
inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming());
123127
}
124128

125-
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
129+
execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
130+
}
131+
132+
/**
 * Executes {@code action} directly through {@code client} under {@code INFERENCE_ORIGIN} and relays the
 * outcome to {@code listener}. This takes the place of the earlier {@code executeAsyncWithOrigin} call
 * sites so that response headers produced by the proxied action can reach the original caller.
 *
 * @param action   the action type to invoke on the client
 * @param request  the request handed to {@code action}
 * @param listener receives the action's result, wrapped so the caller's context is restored first
 */
private <Request extends ActionRequest, Response extends ActionResponse> void execute(
133+
ActionType<Response> action,
134+
Request request,
135+
ActionListener<Response> listener
136+
) {
137+
var threadContext = client.threadPool().getThreadContext();
138+
// Order matters: capture a restorable handle on the caller's context BEFORE stashing it.
// stash the context so we clear the user's security headers, then restore and copy the response headers
// (the `true` argument is what asks for response headers to be carried over on restore).
139+
var supplier = threadContext.newRestorableContext(true);
140+
// The stashed (origin-only) context is active just for the duration of this try block.
try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(INFERENCE_ORIGIN)) {
141+
// The wrapping listener uses `supplier` to restore the caller's context before delegating to `listener`.
client.execute(action, request, new ContextPreservingActionListener<>(supplier, listener));
142+
}
126143
}
127144
}

0 commit comments

Comments
 (0)