Skip to content

Forward port changes from backport of #125562 #126413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Forward port changes from backport of #125562
The backport to `8.x` needed some changes to pass through CI; this
commit forward-ports the relevant bits of those changes back into `main`
to keep the branches aligned.
  • Loading branch information
DaveCTurner committed Apr 7, 2025
commit c0bb2608291176298c037e8b5f73de32b2774e30
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.Iterator;

import static org.elasticsearch.test.ESTestCase.asInstanceOf;
import static org.elasticsearch.test.ESTestCase.fail;
import static org.elasticsearch.transport.BytesRefRecycler.NON_RECYCLING_INSTANCE;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

public class RestResponseUtils {
private RestResponseUtils() {}
Expand Down Expand Up @@ -48,7 +53,7 @@ public static BytesReference getBodyContent(RestResponse restResponse) {
out.flush();
return out.bytes();
} catch (Exception e) {
return ESTestCase.fail(e);
return fail(e);
}
}

Expand All @@ -60,7 +65,18 @@ public static String getTextBodyContent(Iterator<CheckedConsumer<Writer, IOExcep
writer.flush();
return writer.toString();
} catch (Exception e) {
return ESTestCase.fail(e);
return fail(e);
}
}

public static <T extends ToXContent> T setUpXContentMock(T mock) {
try {
when(mock.toXContent(any(), any())).thenAnswer(
invocation -> asInstanceOf(XContentBuilder.class, invocation.getArgument(0)).startObject().endObject()
);
} catch (IOException e) {
fail(e);
}
return mock;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.rest.action.RestActionListener;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.logstash.Pipeline;
import org.elasticsearch.xpack.logstash.action.PutPipelineAction;
import org.elasticsearch.xpack.logstash.action.PutPipelineRequest;
Expand Down Expand Up @@ -55,9 +54,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
new PutPipelineRequest(id, content, request.getXContentType()),
new RestActionListener<>(restChannel) {
@Override
protected void processResponse(PutPipelineResponse putPipelineResponse) throws Exception {
protected void processResponse(PutPipelineResponse putPipelineResponse) {
channel.sendResponse(
new RestResponse(putPipelineResponse.status(), XContentType.JSON.mediaType(), BytesArray.EMPTY)
new RestResponse(putPipelineResponse.status(), RestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.test.rest.RestActionTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;

import java.io.IOException;
import java.util.HashMap;

import static org.elasticsearch.rest.RestResponseUtils.setUpXContentMock;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class RestUpdateTrainedModelDeploymentActionTests extends RestActionTestCase {
public void testNumberOfAllocationInParam() {
Expand All @@ -37,7 +34,7 @@ public void testNumberOfAllocationInParam() {
assertEquals(request.getNumberOfAllocations().intValue(), 5);

executeCalled.set(true);
return newMockResponse();
return setUpXContentMock(mock(CreateTrainedModelAssignmentAction.Response.class));
}));
var params = new HashMap<String, String>();
params.put("number_of_allocations", "5");
Expand All @@ -60,7 +57,7 @@ public void testNumberOfAllocationInBody() {
assertEquals(request.getNumberOfAllocations().intValue(), 6);

executeCalled.set(true);
return newMockResponse();
return setUpXContentMock(mock(CreateTrainedModelAssignmentAction.Response.class));
}));

final String content = """
Expand All @@ -73,16 +70,4 @@ public void testNumberOfAllocationInBody() {
dispatchRequest(inferenceRequest);
assertThat(executeCalled.get(), equalTo(true));
}

private static CreateTrainedModelAssignmentAction.Response newMockResponse() {
final var response = mock(CreateTrainedModelAssignmentAction.Response.class);
try {
when(response.toXContent(any(), any())).thenAnswer(
invocation -> asInstanceOf(XContentBuilder.class, invocation.getArgument(0)).startObject().endObject()
);
} catch (IOException e) {
fail(e);
}
return response;
}
}