[8.x] Remove lots of redundant ref-counting from transport pipeline (#123390) #123554

Open · wants to merge 2 commits into base: 8.19
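At a high level, and as the hunks below show, this backport switches the inbound path from a "caller keeps ownership, callee retains what it needs" model to a "callee takes ownership" model: InboundPipeline#handleBytes now receives the ReleasableBytesReference outright, closes it itself when the pipeline is broken or already closed, and otherwise queues it until the bytes have been decoded. A minimal sketch of that convention, using illustrative names rather than the real implementation:

    import org.elasticsearch.common.bytes.ReleasableBytesReference;

    import java.util.ArrayDeque;

    // Sketch of the ownership convention only; the real logic lives in
    // org.elasticsearch.transport.InboundPipeline (see the diff below).
    class OwnershipSketch {
        private final ArrayDeque<ReleasableBytesReference> pending = new ArrayDeque<>();
        private boolean closed;

        // The caller hands the reference over and must not release it afterwards.
        void handleBytes(ReleasableBytesReference reference) {
            if (closed) {
                reference.close();  // we own it, so we release it
                return;
            }
            pending.add(reference); // no retain(): ownership was transferred
            // ... decode from 'pending', closing each reference once consumed ...
        }
    }

The caller-side counterpart is visible in the first hunk below, where the Netty handler no longer wraps the reference in a try-with-resources block before calling handleBytes.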
@@ -14,7 +14,6 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.network.ThreadWatchdog;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.transport.InboundPipeline;
@@ -51,8 +50,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
final ByteBuf buffer = (ByteBuf) msg;
Netty4TcpChannel channel = ctx.channel().attr(Netty4Transport.CHANNEL_KEY).get();
activityTracker.startActivity();
try (ReleasableBytesReference reference = Netty4Utils.toReleasableBytesReference(buffer)) {
pipeline.handleBytes(channel, reference);
try {
pipeline.handleBytes(channel, Netty4Utils.toReleasableBytesReference(buffer));
} finally {
activityTracker.stopActivity();
}
@@ -12,12 +12,10 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;

import java.util.List;
import io.netty.channel.ChannelInboundHandlerAdapter;

@ChannelHandler.Sharable
public class NettyByteBufSizer extends MessageToMessageDecoder<ByteBuf> {
public class NettyByteBufSizer extends ChannelInboundHandlerAdapter {

public static final NettyByteBufSizer INSTANCE = new NettyByteBufSizer();

@@ -26,14 +24,12 @@ private NettyByteBufSizer() {
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) {
int readableBytes = buf.readableBytes();
if (buf.capacity() >= 1024) {
ByteBuf resized = buf.discardReadBytes().capacity(readableBytes);
assert resized.readableBytes() == readableBytes;
out.add(resized.retain());
} else {
out.add(buf.retain());
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof ByteBuf buf && buf.capacity() >= 1024) {
int readableBytes = buf.readableBytes();
buf = buf.discardReadBytes().capacity(readableBytes);
assert buf.readableBytes() == readableBytes;
}
ctx.fireChannelRead(msg);
}
}
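The sizer change follows the same theme: MessageToMessageDecoder releases the input message after decode() and therefore forces the decoder to retain() anything it forwards through the out list, whereas a plain ChannelInboundHandlerAdapter can resize the buffer in place and pass the original message along with ctx.fireChannelRead, skipping one retain/release pair per read. A minimal sketch of how such a @Sharable, stateless handler is typically installed (the handler name "byte_buf_sizer" is illustrative, not taken from this PR, and the NettyByteBufSizer import is omitted and assumed to resolve):

    import io.netty.channel.ChannelInitializer;
    import io.netty.channel.socket.SocketChannel;

    // Sketch only: the singleton INSTANCE can be shared across every channel's
    // pipeline because the handler is @Sharable and keeps no per-channel state.
    class SizerPipelineSetup extends ChannelInitializer<SocketChannel> {
        @Override
        protected void initChannel(SocketChannel ch) {
            ch.pipeline().addLast("byte_buf_sizer", NettyByteBufSizer.INSTANCE);
            // ... the transport's own inbound message handler would follow here ...
        }
    }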
@@ -18,12 +18,12 @@
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

import java.io.IOException;
import java.io.StreamCorruptedException;
import java.util.function.Consumer;

public class InboundDecoder implements Releasable {

@@ -53,7 +53,7 @@ public InboundDecoder(Recycler<BytesRef> recycler, ByteSizeValue maxHeaderSize,
this.channelType = channelType;
}

public int decode(ReleasableBytesReference reference, Consumer<Object> fragmentConsumer) throws IOException {
public int decode(ReleasableBytesReference reference, CheckedConsumer<Object, IOException> fragmentConsumer) throws IOException {
ensureOpen();
try {
return internalDecode(reference, fragmentConsumer);
@@ -63,7 +63,8 @@ public int internalDecode(ReleasableBytesReference reference, Consumer<Object> fragmentC
}
}

public int internalDecode(ReleasableBytesReference reference, Consumer<Object> fragmentConsumer) throws IOException {
public int internalDecode(ReleasableBytesReference reference, CheckedConsumer<Object, IOException> fragmentConsumer)
throws IOException {
if (isOnHeader()) {
int messageLength = TcpTransport.readMessageLength(reference);
if (messageLength == -1) {
@@ -104,25 +105,28 @@ public int internalDecode(ReleasableBytesReference reference, Consumer<Object> f
}
int remainingToConsume = totalNetworkSize - bytesConsumed;
int maxBytesToConsume = Math.min(reference.length(), remainingToConsume);
ReleasableBytesReference retainedContent;
if (maxBytesToConsume == remainingToConsume) {
retainedContent = reference.retainedSlice(0, maxBytesToConsume);
} else {
retainedContent = reference.retain();
}

int bytesConsumedThisDecode = 0;
if (decompressor != null) {
bytesConsumedThisDecode += decompress(retainedContent);
bytesConsumedThisDecode += decompressor.decompress(
maxBytesToConsume == remainingToConsume ? reference.slice(0, maxBytesToConsume) : reference
);
bytesConsumed += bytesConsumedThisDecode;
ReleasableBytesReference decompressed;
while ((decompressed = decompressor.pollDecompressedPage(isDone())) != null) {
fragmentConsumer.accept(decompressed);
try (var buf = decompressed) {
fragmentConsumer.accept(buf);
}
}
} else {
bytesConsumedThisDecode += maxBytesToConsume;
bytesConsumed += maxBytesToConsume;
fragmentConsumer.accept(retainedContent);
if (maxBytesToConsume == remainingToConsume) {
try (ReleasableBytesReference retained = reference.retainedSlice(0, maxBytesToConsume)) {
fragmentConsumer.accept(retained);
}
} else {
fragmentConsumer.accept(reference);
}
}
if (isDone()) {
finishMessage(fragmentConsumer);
@@ -138,7 +142,7 @@ public void close() {
cleanDecodeState();
}

private void finishMessage(Consumer<Object> fragmentConsumer) {
private void finishMessage(CheckedConsumer<Object, IOException> fragmentConsumer) throws IOException {
cleanDecodeState();
fragmentConsumer.accept(END_CONTENT);
}
@@ -154,12 +158,6 @@ private void cleanDecodeState() {
}
}

private int decompress(ReleasableBytesReference content) throws IOException {
try (content) {
return decompressor.decompress(content);
}
}

private boolean isDone() {
return bytesConsumed == totalNetworkSize;
}
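With the decoder now invoking the fragment consumer directly and releasing decompressed pages itself (note the try-with-resources around pollDecompressedPage above), a consumer only sees each ReleasableBytesReference for the duration of the callback and must retain() it if it needs the bytes afterwards, which is what the updated decoder test below does. A minimal sketch of such a consumer, with illustrative names (RetainingFragmentConsumer, kept) that are not part of the PR:

    import org.elasticsearch.common.bytes.ReleasableBytesReference;
    import org.elasticsearch.core.CheckedConsumer;

    import java.io.IOException;
    import java.util.ArrayDeque;
    import java.util.Queue;

    // Sketch: under the new contract a fragment handed to the decode consumer is
    // only guaranteed to be live during the callback, so a consumer that keeps it
    // must take its own reference and release it later.
    class RetainingFragmentConsumer {
        private final Queue<ReleasableBytesReference> kept = new ArrayDeque<>();

        CheckedConsumer<Object, IOException> consumer() {
            return fragment -> {
                if (fragment instanceof ReleasableBytesReference bytes) {
                    kept.add(bytes.retain()); // our reference; released in releaseAll()
                }
                // control fragments (Header, PING, END_CONTENT) need no ref-counting
            };
        }

        void releaseAll() {
            ReleasableBytesReference bytes;
            while ((bytes = kept.poll()) != null) {
                bytes.close();
            }
        }
    }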
server/src/main/java/org/elasticsearch/transport/InboundPipeline.java (133 changes: 53 additions & 80 deletions)
@@ -11,18 +11,17 @@

import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.function.BiConsumer;
import java.util.function.LongSupplier;

public class InboundPipeline implements Releasable {

private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);
private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true);

private final LongSupplier relativeTimeInMillis;
@@ -56,81 +55,74 @@ public void close() {

public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException {
if (uncaughtException != null) {
reference.close();
throw new IllegalStateException("Pipeline state corrupted by uncaught exception", uncaughtException);
}
try {
doHandleBytes(channel, reference);
channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong());
statsTracker.markBytesRead(reference.length());
if (isClosed) {
reference.close();
return;
}
pending.add(reference);
doHandleBytes(channel);
} catch (Exception e) {
uncaughtException = e;
throw e;
}
}

public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException {
channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong());
statsTracker.markBytesRead(reference.length());
pending.add(reference.retain());

final ArrayList<Object> fragments = fragmentList.get();
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
private void doHandleBytes(TcpChannel channel) throws IOException {
do {
CheckedConsumer<Object, IOException> decodeConsumer = f -> forwardFragment(channel, f);
int bytesDecoded = decoder.decode(pending.peekFirst(), decodeConsumer);
if (bytesDecoded == 0 && pending.size() > 1) {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
try (
ReleasableBytesReference toDecode = new ReleasableBytesReference(
CompositeBytesReference.of(bytesReferences),
() -> Releasables.closeExpectNoException(bytesReferences)
)
) {
bytesDecoded = decoder.decode(toDecode, decodeConsumer);
}
}

if (fragments.isEmpty()) {
continueHandling = false;
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
} else {
try {
forwardFragments(channel, fragments);
} finally {
for (Object fragment : fragments) {
if (fragment instanceof ReleasableBytesReference) {
((ReleasableBytesReference) fragment).close();
}
}
fragments.clear();
}
break;
}
}
} while (pending.isEmpty() == false);
}

private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments) throws IOException {
for (Object fragment : fragments) {
if (fragment instanceof Header) {
headerReceived((Header) fragment);
} else if (fragment instanceof Compression.Scheme) {
assert aggregator.isAggregating();
aggregator.updateCompressionScheme((Compression.Scheme) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, PING_MESSAGE);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
InboundMessage aggregated = aggregator.finishAggregation();
try {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
} finally {
aggregated.decRef();
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
private void forwardFragment(TcpChannel channel, Object fragment) throws IOException {
if (fragment instanceof Header) {
headerReceived((Header) fragment);
} else if (fragment instanceof Compression.Scheme) {
assert aggregator.isAggregating();
aggregator.updateCompressionScheme((Compression.Scheme) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, PING_MESSAGE);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
InboundMessage aggregated = aggregator.finishAggregation();
try {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
} finally {
aggregated.decRef();
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
}
}

@@ -139,25 +131,6 @@ protected void headerReceived(Header header) {
aggregator.headerReceived(header);
}

private static boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeExpectNoException(bytesReferences);
return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
}
}

private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
@@ -117,8 +117,6 @@ public void testDecode() throws IOException {
assertEquals(messageBytes, content);
// Ref count is incremented since the bytes are forwarded as a fragment
assertTrue(releasable2.hasReferences());
releasable2.decRef();
assertTrue(releasable2.hasReferences());
assertTrue(releasable2.decRef());
assertEquals(InboundDecoder.END_CONTENT, endMarker);
}
@@ -433,7 +431,12 @@ public void testCompressedDecode() throws IOException {

final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed);
final ReleasableBytesReference releasable2 = wrapAsReleasable(bytes2);
int bytesConsumed2 = decoder.decode(releasable2, fragments::add);
int bytesConsumed2 = decoder.decode(releasable2, e -> {
fragments.add(e);
if (e instanceof ReleasableBytesReference reference) {
reference.retain();
}
});
assertEquals(totalBytes.length() - totalHeaderSize, bytesConsumed2);

final Object compressionScheme = fragments.get(0);
@@ -159,12 +159,11 @@ public void testPipelineHandling() throws IOException {
final int remainingBytes = networkBytes.length() - currentOffset;
final int bytesToRead = Math.min(randomIntBetween(1, 32 * 1024), remainingBytes);
final BytesReference slice = networkBytes.slice(currentOffset, bytesToRead);
try (ReleasableBytesReference reference = new ReleasableBytesReference(slice, () -> {})) {
toRelease.add(reference);
bytesReceived += reference.length();
pipeline.handleBytes(channel, reference);
currentOffset += bytesToRead;
}
ReleasableBytesReference reference = new ReleasableBytesReference(slice, () -> {});
toRelease.add(reference);
bytesReceived += reference.length();
pipeline.handleBytes(channel, reference);
currentOffset += bytesToRead;
}

final int messages = expected.size();
@@ -288,13 +287,12 @@ public void testEnsureBodyIsNotPrematurelyReleased() throws IOException {
final Releasable releasable = () -> bodyReleased.set(true);
final int from = totalHeaderSize - 1;
final BytesReference partHeaderPartBody = reference.slice(from, reference.length() - from - 1);
try (ReleasableBytesReference slice = new ReleasableBytesReference(partHeaderPartBody, releasable)) {
pipeline.handleBytes(new FakeTcpChannel(), slice);
}
pipeline.handleBytes(new FakeTcpChannel(), new ReleasableBytesReference(partHeaderPartBody, releasable));
assertFalse(bodyReleased.get());
try (ReleasableBytesReference slice = new ReleasableBytesReference(reference.slice(reference.length() - 1, 1), releasable)) {
pipeline.handleBytes(new FakeTcpChannel(), slice);
}
pipeline.handleBytes(
new FakeTcpChannel(),
new ReleasableBytesReference(reference.slice(reference.length() - 1, 1), releasable)
);
assertTrue(bodyReleased.get());
}
}