9
9
10
10
import org .elasticsearch .ElasticsearchStatusException ;
11
11
import org .elasticsearch .action .ActionListener ;
12
+ import org .elasticsearch .action .ActionRequest ;
13
+ import org .elasticsearch .action .ActionResponse ;
14
+ import org .elasticsearch .action .ActionType ;
12
15
import org .elasticsearch .action .support .ActionFilters ;
16
+ import org .elasticsearch .action .support .ContextPreservingActionListener ;
13
17
import org .elasticsearch .action .support .HandledTransportAction ;
14
18
import org .elasticsearch .client .internal .Client ;
15
19
import org .elasticsearch .common .util .concurrent .EsExecutors ;
20
+ import org .elasticsearch .common .util .concurrent .ThreadContext ;
16
21
import org .elasticsearch .common .xcontent .XContentHelper ;
17
22
import org .elasticsearch .inference .TaskType ;
18
23
import org .elasticsearch .inference .UnparsedModel ;
30
35
import java .io .IOException ;
31
36
32
37
import static org .elasticsearch .xpack .core .ClientHelper .INFERENCE_ORIGIN ;
33
- import static org .elasticsearch .xpack .core .ClientHelper .executeAsyncWithOrigin ;
34
38
35
39
public class TransportInferenceActionProxy extends HandledTransportAction <InferenceActionProxy .Request , InferenceAction .Response > {
36
40
private final ModelRegistry modelRegistry ;
@@ -103,7 +107,7 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request,
103
107
);
104
108
}
105
109
106
- executeAsyncWithOrigin ( client , INFERENCE_ORIGIN , UnifiedCompletionAction .INSTANCE , unifiedRequest , unifiedErrorFormatListener );
110
+ execute ( UnifiedCompletionAction .INSTANCE , unifiedRequest , listener );
107
111
} catch (Exception e ) {
108
112
unifiedErrorFormatListener .onFailure (e );
109
113
}
@@ -122,6 +126,19 @@ private void sendInferenceActionRequest(InferenceActionProxy.Request request, Ac
122
126
inferenceActionRequestBuilder .setInferenceTimeout (request .getTimeout ()).setStream (request .isStreaming ());
123
127
}
124
128
125
- executeAsyncWithOrigin (client , INFERENCE_ORIGIN , InferenceAction .INSTANCE , inferenceActionRequestBuilder .build (), listener );
129
+ execute (InferenceAction .INSTANCE , inferenceActionRequestBuilder .build (), listener );
130
+ }
131
+
132
+ private <Request extends ActionRequest , Response extends ActionResponse > void execute (
133
+ ActionType <Response > action ,
134
+ Request request ,
135
+ ActionListener <Response > listener
136
+ ) {
137
+ var threadContext = client .threadPool ().getThreadContext ();
138
+ // stash the context so we clear the user's security headers, then restore and copy the response headers
139
+ var supplier = threadContext .newRestorableContext (true );
140
+ try (ThreadContext .StoredContext ignore = threadContext .stashWithOrigin (INFERENCE_ORIGIN )) {
141
+ client .execute (action , request , new ContextPreservingActionListener <>(supplier , listener ));
142
+ }
126
143
}
127
144
}
0 commit comments