Skip to content

Remove ChannelType.Client #127432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
Expand All @@ -36,21 +34,17 @@ public class InboundDecoder implements Releasable {
private int bytesConsumed = 0;
private boolean isCompressed = false;
private boolean isClosed = false;
private final ByteSizeValue maxHeaderSize;
private final ChannelType channelType;
private final int maxHeaderSize;
private final boolean isServerChannel;

public InboundDecoder(Recycler<BytesRef> recycler) {
this(recycler, ByteSizeValue.of(2, ByteSizeUnit.GB), ChannelType.MIX);
this(recycler, Integer.MAX_VALUE, false);
}

public InboundDecoder(Recycler<BytesRef> recycler, ChannelType channelType) {
this(recycler, ByteSizeValue.of(2, ByteSizeUnit.GB), channelType);
}

public InboundDecoder(Recycler<BytesRef> recycler, ByteSizeValue maxHeaderSize, ChannelType channelType) {
public InboundDecoder(Recycler<BytesRef> recycler, int maxHeaderSize, boolean isServerChannel) {
this.recycler = recycler;
this.maxHeaderSize = maxHeaderSize;
this.channelType = channelType;
this.isServerChannel = isServerChannel;
}

public int decode(ReleasableBytesReference reference, CheckedConsumer<Object, IOException> fragmentConsumer) throws IOException {
Expand All @@ -73,13 +67,13 @@ public int internalDecode(ReleasableBytesReference reference, CheckedConsumer<Ob
fragmentConsumer.accept(PING);
return 6;
} else {
int headerBytesToRead = headerBytesToRead(reference, maxHeaderSize);
int headerBytesToRead = headerBytesToRead(reference);
if (headerBytesToRead == 0) {
return 0;
} else {
totalNetworkSize = messageLength + TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE;

Header header = readHeader(messageLength, reference, channelType);
Header header = readHeader(messageLength, reference);
bytesConsumed += headerBytesToRead;
if (header.isCompressed()) {
isCompressed = true;
Expand Down Expand Up @@ -160,11 +154,7 @@ private boolean isDone() {
return bytesConsumed == totalNetworkSize;
}

private static int headerBytesToRead(BytesReference reference, ByteSizeValue maxHeaderSize) throws StreamCorruptedException {
if (reference.length() < TcpHeader.BYTES_REQUIRED_FOR_VERSION) {
return 0;
}

private int headerBytesToRead(BytesReference reference) throws StreamCorruptedException {
if (reference.length() <= TcpHeader.HEADER_SIZE) {
return 0;
} else {
Expand All @@ -173,7 +163,7 @@ private static int headerBytesToRead(BytesReference reference, ByteSizeValue max
throw new StreamCorruptedException("invalid negative variable header size: " + variableHeaderSize);
}
int totalHeaderSize = TcpHeader.HEADER_SIZE + variableHeaderSize;
if (totalHeaderSize > maxHeaderSize.getBytes()) {
if (totalHeaderSize > maxHeaderSize) {
throw new StreamCorruptedException("header size [" + totalHeaderSize + "] exceeds limit of [" + maxHeaderSize + "]");
}
if (totalHeaderSize > reference.length()) {
Expand All @@ -184,18 +174,16 @@ private static int headerBytesToRead(BytesReference reference, ByteSizeValue max
}
}

private static Header readHeader(int networkMessageSize, BytesReference bytesReference, ChannelType channelType) throws IOException {
private Header readHeader(int networkMessageSize, BytesReference bytesReference) throws IOException {
try (StreamInput streamInput = bytesReference.streamInput()) {
streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
long requestId = streamInput.readLong();
byte status = streamInput.readByte();
int remoteVersion = streamInput.readInt();

Header header = new Header(networkMessageSize, requestId, status, TransportVersion.fromId(remoteVersion));
if (channelType == ChannelType.SERVER && header.isResponse()) {
if (isServerChannel && header.isResponse()) {
throw new IllegalArgumentException("server channels do not accept inbound responses, only requests, closing channel");
} else if (channelType == ChannelType.CLIENT && header.isRequest()) {
throw new IllegalArgumentException("client channels do not accept inbound requests, only responses, closing channel");
}
if (header.isHandshake()) {
checkHandshakeVersionCompatibility(header.getVersion());
Expand Down Expand Up @@ -243,9 +231,4 @@ static void checkVersionCompatibility(TransportVersion remoteVersion) {
}
}

public enum ChannelType {
SERVER,
CLIENT,
MIX
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.TransportVersionUtils;
import org.elasticsearch.transport.InboundDecoder.ChannelType;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -160,56 +159,6 @@ private void doHandshakeCompatibilityTest(TransportVersion transportVersion, Com
}
}

public void testClientChannelTypeFailsDecodingRequests() throws Exception {
String action = "test-request";
long requestId = randomNonNegativeLong();
if (randomBoolean()) {
final String headerKey = randomAlphaOfLength(10);
final String headerValue = randomAlphaOfLength(20);
if (randomBoolean()) {
threadContext.putHeader(headerKey, headerValue);
} else {
threadContext.addResponseHeader(headerKey, headerValue);
}
}
// a request
final var isHandshake = randomBoolean();
final var version = isHandshake
? randomFrom(TransportHandshaker.ALLOWED_HANDSHAKE_VERSIONS)
: TransportVersionUtils.randomCompatibleVersion(random());
logger.info("--> version = {}", version);

try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = OutboundHandler.serialize(
OutboundHandler.MessageDirection.REQUEST,
action,
requestId,
isHandshake,
version,
randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4, null),
new TestRequest(randomAlphaOfLength(100)),
threadContext,
os
);
try (InboundDecoder clientDecoder = new InboundDecoder(recycler, ChannelType.CLIENT)) {
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> clientDecoder.decode(wrapAsReleasable(bytes), ignored -> {})
);
assertThat(e.getMessage(), containsString("client channels do not accept inbound requests, only responses"));
}
// the same message will be decoded by a server or mixed decoder
try (InboundDecoder decoder = new InboundDecoder(recycler, randomFrom(ChannelType.SERVER, ChannelType.MIX))) {
final ArrayList<Object> fragments = new ArrayList<>();
int bytesConsumed = decoder.decode(wrapAsReleasable(bytes), fragments::add);
int totalHeaderSize = TcpHeader.HEADER_SIZE + bytes.getInt(TcpHeader.VARIABLE_HEADER_SIZE_POSITION);
assertEquals(totalHeaderSize, bytesConsumed);
final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
}
}
}

public void testServerChannelTypeFailsDecodingResponses() throws Exception {
long requestId = randomNonNegativeLong();
if (randomBoolean()) {
Expand Down Expand Up @@ -239,13 +188,13 @@ public void testServerChannelTypeFailsDecodingResponses() throws Exception {
threadContext,
os
);
try (InboundDecoder decoder = new InboundDecoder(recycler, ChannelType.SERVER)) {
try (InboundDecoder decoder = new InboundDecoder(recycler, Integer.MAX_VALUE, true)) {
final ReleasableBytesReference releasable1 = wrapAsReleasable(bytes);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> decoder.decode(releasable1, ignored -> {}));
assertThat(e.getMessage(), containsString("server channels do not accept inbound responses, only requests"));
}
// the same message will be decoded by a client or mixed decoder
try (InboundDecoder decoder = new InboundDecoder(recycler, randomFrom(ChannelType.CLIENT, ChannelType.MIX))) {
try (InboundDecoder decoder = new InboundDecoder(recycler)) {
final ArrayList<Object> fragments = new ArrayList<>();
int bytesConsumed = decoder.decode(wrapAsReleasable(bytes), fragments::add);
int totalHeaderSize = TcpHeader.HEADER_SIZE + bytes.getInt(TcpHeader.VARIABLE_HEADER_SIZE_POSITION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;

import static org.elasticsearch.transport.InboundDecoder.ChannelType.SERVER;
import static org.elasticsearch.transport.RemoteClusterPortSettings.REMOTE_CLUSTER_PROFILE;
import static org.elasticsearch.xpack.core.XPackSettings.REMOTE_CLUSTER_CLIENT_SSL_ENABLED;
import static org.elasticsearch.xpack.core.XPackSettings.REMOTE_CLUSTER_CLIENT_SSL_PREFIX;
Expand All @@ -80,6 +79,7 @@ public class SecurityNetty4Transport extends Netty4Transport {
private final SslConfiguration remoteClusterClientSslConfiguration;
private final RemoteClusterClientBootstrapOptions remoteClusterClientBootstrapOptions;
private final CrossClusterAccessAuthenticationService crossClusterAccessAuthenticationService;
private final int maxHeaderSize;

public SecurityNetty4Transport(
final Settings settings,
Expand Down Expand Up @@ -120,6 +120,7 @@ public SecurityNetty4Transport(
this.remoteClusterClientSslConfiguration = null;
}
this.remoteClusterClientBootstrapOptions = RemoteClusterClientBootstrapOptions.fromSettings(settings);
this.maxHeaderSize = Math.toIntExact(RemoteClusterPortSettings.MAX_REQUEST_HEADER_SIZE.get(settings).getBytes());
}

@Override
Expand Down Expand Up @@ -167,7 +168,7 @@ protected InboundPipeline getInboundPipeline(Channel channel, boolean isRemoteCl
return new InboundPipeline(
getStatsTracker(),
threadPool.relativeTimeInMillisSupplier(),
new InboundDecoder(recycler, RemoteClusterPortSettings.MAX_REQUEST_HEADER_SIZE.get(settings), SERVER),
new InboundDecoder(recycler, maxHeaderSize, true),
new InboundAggregator(getInflightBreaker(), getRequestHandlers()::getHandler, ignoreDeserializationErrors()),
this::inboundMessage
) {
Expand Down