From 2964b154d61bce91f6450c5b58feeed4885a04ba Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Wed, 26 Feb 2025 23:11:17 +0100 Subject: [PATCH] Remove lots of redundant ref-counting from transport pipeline (#123390) We can do with a whole lot less in ref-counting, avoiding lots of contention and speeding up the logic in general by only incrementing ref-counts where ownership is unclear while avoiding count changes on obvious "moves". --- .../netty4/Netty4MessageInboundHandler.java | 5 +- .../transport/netty4/NettyByteBufSizer.java | 20 ++- .../transport/InboundDecoder.java | 38 +++-- .../transport/InboundPipeline.java | 133 +++++++----------- .../transport/InboundDecoderTests.java | 9 +- .../transport/InboundPipelineTests.java | 22 ++- 6 files changed, 97 insertions(+), 130 deletions(-) diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageInboundHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageInboundHandler.java index 8fdb7051e2be6..46f810ed2d9eb 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageInboundHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageInboundHandler.java @@ -14,7 +14,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.network.ThreadWatchdog; import org.elasticsearch.core.Releasables; import org.elasticsearch.transport.InboundPipeline; @@ -51,8 +50,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception final ByteBuf buffer = (ByteBuf) msg; Netty4TcpChannel channel = ctx.channel().attr(Netty4Transport.CHANNEL_KEY).get(); activityTracker.startActivity(); - try (ReleasableBytesReference reference = Netty4Utils.toReleasableBytesReference(buffer)) { - pipeline.handleBytes(channel, reference); + try { + pipeline.handleBytes(channel, Netty4Utils.toReleasableBytesReference(buffer)); } finally { activityTracker.stopActivity(); } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyByteBufSizer.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyByteBufSizer.java index 2d62f8eb19e0b..4a9be0acaaa4c 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyByteBufSizer.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyByteBufSizer.java @@ -12,12 +12,10 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToMessageDecoder; - -import java.util.List; +import io.netty.channel.ChannelInboundHandlerAdapter; @ChannelHandler.Sharable -public class NettyByteBufSizer extends MessageToMessageDecoder { +public class NettyByteBufSizer extends ChannelInboundHandlerAdapter { public static final NettyByteBufSizer INSTANCE = new NettyByteBufSizer(); @@ -26,14 +24,12 @@ private NettyByteBufSizer() { } @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List out) { - int readableBytes = buf.readableBytes(); - if (buf.capacity() >= 1024) { - ByteBuf resized = buf.discardReadBytes().capacity(readableBytes); - assert resized.readableBytes() == readableBytes; - out.add(resized.retain()); - } else { - out.add(buf.retain()); + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof ByteBuf buf && buf.capacity() >= 1024) { + int readableBytes = buf.readableBytes(); + buf = buf.discardReadBytes().capacity(readableBytes); + assert buf.readableBytes() == readableBytes; } + ctx.fireChannelRead(msg); } } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java index e2a1b010bad06..d1afafa7c1aeb 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java @@ -18,12 +18,12 @@ import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import java.io.IOException; import java.io.StreamCorruptedException; -import java.util.function.Consumer; public class InboundDecoder implements Releasable { @@ -53,7 +53,7 @@ public InboundDecoder(Recycler recycler, ByteSizeValue maxHeaderSize, this.channelType = channelType; } - public int decode(ReleasableBytesReference reference, Consumer fragmentConsumer) throws IOException { + public int decode(ReleasableBytesReference reference, CheckedConsumer fragmentConsumer) throws IOException { ensureOpen(); try { return internalDecode(reference, fragmentConsumer); @@ -63,7 +63,8 @@ public int decode(ReleasableBytesReference reference, Consumer fragmentC } } - public int internalDecode(ReleasableBytesReference reference, Consumer fragmentConsumer) throws IOException { + public int internalDecode(ReleasableBytesReference reference, CheckedConsumer fragmentConsumer) + throws IOException { if (isOnHeader()) { int messageLength = TcpTransport.readMessageLength(reference); if (messageLength == -1) { @@ -104,25 +105,28 @@ public int internalDecode(ReleasableBytesReference reference, Consumer f } int remainingToConsume = totalNetworkSize - bytesConsumed; int maxBytesToConsume = Math.min(reference.length(), remainingToConsume); - ReleasableBytesReference retainedContent; - if (maxBytesToConsume == remainingToConsume) { - retainedContent = reference.retainedSlice(0, maxBytesToConsume); - } else { - retainedContent = reference.retain(); - } - int bytesConsumedThisDecode = 0; if (decompressor != null) { - bytesConsumedThisDecode += decompress(retainedContent); + bytesConsumedThisDecode += decompressor.decompress( + maxBytesToConsume == remainingToConsume ? reference.slice(0, maxBytesToConsume) : reference + ); bytesConsumed += bytesConsumedThisDecode; ReleasableBytesReference decompressed; while ((decompressed = decompressor.pollDecompressedPage(isDone())) != null) { - fragmentConsumer.accept(decompressed); + try (var buf = decompressed) { + fragmentConsumer.accept(buf); + } } } else { bytesConsumedThisDecode += maxBytesToConsume; bytesConsumed += maxBytesToConsume; - fragmentConsumer.accept(retainedContent); + if (maxBytesToConsume == remainingToConsume) { + try (ReleasableBytesReference retained = reference.retainedSlice(0, maxBytesToConsume)) { + fragmentConsumer.accept(retained); + } + } else { + fragmentConsumer.accept(reference); + } } if (isDone()) { finishMessage(fragmentConsumer); @@ -138,7 +142,7 @@ public void close() { cleanDecodeState(); } - private void finishMessage(Consumer fragmentConsumer) { + private void finishMessage(CheckedConsumer fragmentConsumer) throws IOException { cleanDecodeState(); fragmentConsumer.accept(END_CONTENT); } @@ -154,12 +158,6 @@ private void cleanDecodeState() { } } - private int decompress(ReleasableBytesReference content) throws IOException { - try (content) { - return decompressor.decompress(content); - } - } - private boolean isDone() { return bytesConsumed == totalNetworkSize; } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java b/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java index 35665e95c8030..abc3e29727b4b 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java @@ -11,18 +11,17 @@ import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import java.io.IOException; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.function.BiConsumer; import java.util.function.LongSupplier; public class InboundPipeline implements Releasable { - private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true); private final LongSupplier relativeTimeInMillis; @@ -56,81 +55,74 @@ public void close() { public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException { if (uncaughtException != null) { + reference.close(); throw new IllegalStateException("Pipeline state corrupted by uncaught exception", uncaughtException); } try { - doHandleBytes(channel, reference); + channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); + statsTracker.markBytesRead(reference.length()); + if (isClosed) { + reference.close(); + return; + } + pending.add(reference); + doHandleBytes(channel); } catch (Exception e) { uncaughtException = e; throw e; } } - public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException { - channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); - statsTracker.markBytesRead(reference.length()); - pending.add(reference.retain()); - - final ArrayList fragments = fragmentList.get(); - boolean continueHandling = true; - - while (continueHandling && isClosed == false) { - boolean continueDecoding = true; - while (continueDecoding && pending.isEmpty() == false) { - try (ReleasableBytesReference toDecode = getPendingBytes()) { - final int bytesDecoded = decoder.decode(toDecode, fragments::add); - if (bytesDecoded != 0) { - releasePendingBytes(bytesDecoded); - if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { - continueDecoding = false; - } - } else { - continueDecoding = false; - } + private void doHandleBytes(TcpChannel channel) throws IOException { + do { + CheckedConsumer decodeConsumer = f -> forwardFragment(channel, f); + int bytesDecoded = decoder.decode(pending.peekFirst(), decodeConsumer); + if (bytesDecoded == 0 && pending.size() > 1) { + final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; + int index = 0; + for (ReleasableBytesReference pendingReference : pending) { + bytesReferences[index] = pendingReference.retain(); + ++index; + } + try ( + ReleasableBytesReference toDecode = new ReleasableBytesReference( + CompositeBytesReference.of(bytesReferences), + () -> Releasables.closeExpectNoException(bytesReferences) + ) + ) { + bytesDecoded = decoder.decode(toDecode, decodeConsumer); } } - - if (fragments.isEmpty()) { - continueHandling = false; + if (bytesDecoded != 0) { + releasePendingBytes(bytesDecoded); } else { - try { - forwardFragments(channel, fragments); - } finally { - for (Object fragment : fragments) { - if (fragment instanceof ReleasableBytesReference) { - ((ReleasableBytesReference) fragment).close(); - } - } - fragments.clear(); - } + break; } - } + } while (pending.isEmpty() == false); } - private void forwardFragments(TcpChannel channel, ArrayList fragments) throws IOException { - for (Object fragment : fragments) { - if (fragment instanceof Header) { - headerReceived((Header) fragment); - } else if (fragment instanceof Compression.Scheme) { - assert aggregator.isAggregating(); - aggregator.updateCompressionScheme((Compression.Scheme) fragment); - } else if (fragment == InboundDecoder.PING) { - assert aggregator.isAggregating() == false; - messageHandler.accept(channel, PING_MESSAGE); - } else if (fragment == InboundDecoder.END_CONTENT) { - assert aggregator.isAggregating(); - InboundMessage aggregated = aggregator.finishAggregation(); - try { - statsTracker.markMessageReceived(); - messageHandler.accept(channel, aggregated); - } finally { - aggregated.decRef(); - } - } else { - assert aggregator.isAggregating(); - assert fragment instanceof ReleasableBytesReference; - aggregator.aggregate((ReleasableBytesReference) fragment); + private void forwardFragment(TcpChannel channel, Object fragment) throws IOException { + if (fragment instanceof Header) { + headerReceived((Header) fragment); + } else if (fragment instanceof Compression.Scheme) { + assert aggregator.isAggregating(); + aggregator.updateCompressionScheme((Compression.Scheme) fragment); + } else if (fragment == InboundDecoder.PING) { + assert aggregator.isAggregating() == false; + messageHandler.accept(channel, PING_MESSAGE); + } else if (fragment == InboundDecoder.END_CONTENT) { + assert aggregator.isAggregating(); + InboundMessage aggregated = aggregator.finishAggregation(); + try { + statsTracker.markMessageReceived(); + messageHandler.accept(channel, aggregated); + } finally { + aggregated.decRef(); } + } else { + assert aggregator.isAggregating(); + assert fragment instanceof ReleasableBytesReference; + aggregator.aggregate((ReleasableBytesReference) fragment); } } @@ -139,25 +131,6 @@ protected void headerReceived(Header header) { aggregator.headerReceived(header); } - private static boolean endOfMessage(Object fragment) { - return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; - } - - private ReleasableBytesReference getPendingBytes() { - if (pending.size() == 1) { - return pending.peekFirst().retain(); - } else { - final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; - int index = 0; - for (ReleasableBytesReference pendingReference : pending) { - bytesReferences[index] = pendingReference.retain(); - ++index; - } - final Releasable releasable = () -> Releasables.closeExpectNoException(bytesReferences); - return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); - } - } - private void releasePendingBytes(int bytesConsumed) { int bytesToRelease = bytesConsumed; while (bytesToRelease != 0) { diff --git a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java index 61a62d0e2e198..118694b654535 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java @@ -117,8 +117,6 @@ public void testDecode() throws IOException { assertEquals(messageBytes, content); // Ref count is incremented since the bytes are forwarded as a fragment assertTrue(releasable2.hasReferences()); - releasable2.decRef(); - assertTrue(releasable2.hasReferences()); assertTrue(releasable2.decRef()); assertEquals(InboundDecoder.END_CONTENT, endMarker); } @@ -433,7 +431,12 @@ public void testCompressedDecode() throws IOException { final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed); final ReleasableBytesReference releasable2 = wrapAsReleasable(bytes2); - int bytesConsumed2 = decoder.decode(releasable2, fragments::add); + int bytesConsumed2 = decoder.decode(releasable2, e -> { + fragments.add(e); + if (e instanceof ReleasableBytesReference reference) { + reference.retain(); + } + }); assertEquals(totalBytes.length() - totalHeaderSize, bytesConsumed2); final Object compressionScheme = fragments.get(0); diff --git a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java index 282cb720f52f3..d0c6cd8b00ff5 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java @@ -159,12 +159,11 @@ public void testPipelineHandling() throws IOException { final int remainingBytes = networkBytes.length() - currentOffset; final int bytesToRead = Math.min(randomIntBetween(1, 32 * 1024), remainingBytes); final BytesReference slice = networkBytes.slice(currentOffset, bytesToRead); - try (ReleasableBytesReference reference = new ReleasableBytesReference(slice, () -> {})) { - toRelease.add(reference); - bytesReceived += reference.length(); - pipeline.handleBytes(channel, reference); - currentOffset += bytesToRead; - } + ReleasableBytesReference reference = new ReleasableBytesReference(slice, () -> {}); + toRelease.add(reference); + bytesReceived += reference.length(); + pipeline.handleBytes(channel, reference); + currentOffset += bytesToRead; } final int messages = expected.size(); @@ -288,13 +287,12 @@ public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { final Releasable releasable = () -> bodyReleased.set(true); final int from = totalHeaderSize - 1; final BytesReference partHeaderPartBody = reference.slice(from, reference.length() - from - 1); - try (ReleasableBytesReference slice = new ReleasableBytesReference(partHeaderPartBody, releasable)) { - pipeline.handleBytes(new FakeTcpChannel(), slice); - } + pipeline.handleBytes(new FakeTcpChannel(), new ReleasableBytesReference(partHeaderPartBody, releasable)); assertFalse(bodyReleased.get()); - try (ReleasableBytesReference slice = new ReleasableBytesReference(reference.slice(reference.length() - 1, 1), releasable)) { - pipeline.handleBytes(new FakeTcpChannel(), slice); - } + pipeline.handleBytes( + new FakeTcpChannel(), + new ReleasableBytesReference(reference.slice(reference.length() - 1, 1), releasable) + ); assertTrue(bodyReleased.get()); } }