
Commit 7130837

LocalChannel write when peer closed leak
Motivation:
If LocalChannel.doWrite() executes while the peer's state changes from CONNECTED to CLOSED, it is possible that some promises won't be completed and buffers will be leaked.

Modifications:
- Check the peer's state in doWrite() to avoid the race condition.

Result:
All write operations release their buffers, and the associated promise is completed.
1 parent fd70dd6 commit 7130837
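
To make the race concrete before reading the diff: the writer re-checks the peer's state immediately before handing a message over, and if the peer has already moved to CLOSED it fails the write instead of queuing data that would never be drained. Below is a minimal sketch of that pattern only, using a simplified stand-in class; the class, field, and method names are illustrative and are not the actual LocalChannel internals.

    import java.util.Queue;
    import java.util.concurrent.ConcurrentLinkedQueue;

    // Simplified stand-in for the peer-state check added to doWrite(); names are illustrative.
    final class PeerStateCheckSketch {
        static final int CONNECTED = 2; // mirrors the numeric states used by LocalChannel
        static final int CLOSED = 3;

        volatile int state = CONNECTED;
        final Queue<Object> inboundBuffer = new ConcurrentLinkedQueue<Object>();

        /**
         * Returns true if the message was queued for the peer; false if the peer closed
         * concurrently and the caller must fail the promise and release the message.
         */
        static boolean offerToPeer(PeerStateCheckSketch peer, Object msg) {
            // The peer may transition to CLOSED after the writer's own state check,
            // so re-check here; otherwise the queued message would never be drained.
            if (peer.state == CONNECTED) {
                return peer.inboundBuffer.offer(msg);
            }
            return false;
        }
    }

In the actual patch, doWrite() does the equivalent: on the connected branch it retains the message, adds it to peer.inboundBuffer and calls in.remove(), and on the closed branch it calls in.remove(CLOSED_CHANNEL_EXCEPTION), which fails the promise and releases the buffer.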

File tree

2 files changed: +198 additions, -27 deletions

transport/src/main/java/io/netty/channel/local/LocalChannel.java

Lines changed: 66 additions & 20 deletions
@@ -27,8 +27,9 @@
 import io.netty.channel.EventLoop;
 import io.netty.channel.SingleThreadEventLoop;
 import io.netty.util.ReferenceCountUtil;
-import io.netty.util.concurrent.SingleThreadEventExecutor;
 import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.SingleThreadEventExecutor;
+import io.netty.util.internal.EmptyArrays;
 import io.netty.util.internal.InternalThreadLocalMap;
 import io.netty.util.internal.OneTimeTask;
 import io.netty.util.internal.PlatformDependent;
@@ -50,9 +51,10 @@ public class LocalChannel extends AbstractChannel {
     private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER;
     private static final ChannelMetadata METADATA = new ChannelMetadata(false);
     private static final int MAX_READER_STACK_DEPTH = 8;
+    private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();

     private final ChannelConfig config = new DefaultChannelConfig(this);
-    // To futher optimize this we could write our own SPSC queue.
+    // To further optimize this we could write our own SPSC queue.
     private final Queue<Object> inboundBuffer = PlatformDependent.newMpscQueue();
     private final Runnable readTask = new Runnable() {
         @Override
@@ -94,6 +96,7 @@ public void run() {
                     AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture");
         }
         FINISH_READ_FUTURE_UPDATER = finishReadFutureUpdater;
+        CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
     }

     public LocalChannel() {
@@ -216,18 +219,29 @@ protected void doDisconnect() throws Exception {
     protected void doClose() throws Exception {
         final LocalChannel peer = this.peer;
         if (state <= 2) {
-            // To preserve ordering of events we must process any pending reads
-            if (writeInProgress && peer != null) {
-                finishPeerRead(peer);
-            }
             // Update all internal state before the closeFuture is notified.
             if (localAddress != null) {
                 if (parent() == null) {
                     LocalChannelRegistry.unregister(localAddress);
                 }
                 localAddress = null;
             }
+
+            // State change must happen before finishPeerRead to ensure writes are released either in doWrite or
+            // channelRead.
             state = 3;
+
+            ChannelPromise promise = connectPromise;
+            if (promise != null) {
+                // Use tryFailure() instead of setFailure() to avoid the race against cancel().
+                promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
+                connectPromise = null;
+            }
+
+            // To preserve ordering of events we must process any pending reads
+            if (writeInProgress && peer != null) {
+                finishPeerRead(peer);
+            }
         }

         if (peer != null && peer.isActive()) {
@@ -239,12 +253,18 @@ protected void doClose() throws Exception {
             } else {
                 // This value may change, and so we should save it before executing the Runnable.
                 final boolean peerWriteInProgress = peer.writeInProgress;
-                peer.eventLoop().execute(new OneTimeTask() {
-                    @Override
-                    public void run() {
-                        doPeerClose(peer, peerWriteInProgress);
-                    }
-                });
+                try {
+                    peer.eventLoop().execute(new OneTimeTask() {
+                        @Override
+                        public void run() {
+                            doPeerClose(peer, peerWriteInProgress);
+                        }
+                    });
+                } catch (RuntimeException e) {
+                    // The peer close may attempt to drain this.inboundBuffers. If that fails make sure it is drained.
+                    releaseInboundBuffers();
+                    throw e;
+                }
             }
             this.peer = null;
         }
@@ -293,7 +313,12 @@ protected void doBeginRead() throws Exception {
                 threadLocals.setLocalChannelReaderStackDepth(stackDepth);
             }
         } else {
-            eventLoop().execute(readTask);
+            try {
+                eventLoop().execute(readTask);
+            } catch (RuntimeException e) {
+                releaseInboundBuffers();
+                throw e;
+            }
         }
     }

@@ -303,7 +328,7 @@ protected void doWrite(ChannelOutboundBuffer in) throws Exception {
             throw new NotYetConnectedException();
         }
         if (state > 2) {
-            throw new ClosedChannelException();
+            throw CLOSED_CHANNEL_EXCEPTION;
         }

         final LocalChannel peer = this.peer;
@@ -316,8 +341,14 @@ protected void doWrite(ChannelOutboundBuffer in) throws Exception {
                 break;
             }
             try {
-                peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
-                in.remove();
+                // It is possible the peer could have closed while we are writing, and in this case we should
+                // simulate real socket behavior and ensure the write operation is failed.
+                if (peer.state == 2) {
+                    peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
+                    in.remove();
+                } else {
+                    in.remove(CLOSED_CHANNEL_EXCEPTION);
+                }
             } catch (Throwable cause) {
                 in.remove(cause);
             }
@@ -352,10 +383,25 @@ public void run() {
                 finishPeerRead0(peer);
             }
         };
-        if (peer.writeInProgress) {
-            peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
-        } else {
-            peer.eventLoop().execute(finishPeerReadTask);
+        try {
+            if (peer.writeInProgress) {
+                peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
+            } else {
+                peer.eventLoop().execute(finishPeerReadTask);
+            }
+        } catch (RuntimeException e) {
+            peer.releaseInboundBuffers();
+            throw e;
+        }
+    }
+
+    private void releaseInboundBuffers() {
+        for (;;) {
+            Object o = inboundBuffer.poll();
+            if (o == null) {
+                break;
+            }
+            ReferenceCountUtil.release(o);
         }
     }

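From the caller's perspective, the observable contract after this change is that a write racing with the peer's close fails its promise with ClosedChannelException and the message is released by the transport rather than leaked. A short, hedged caller-side sketch follows; the class and method names here are hypothetical and not part of the patch.

    import io.netty.buffer.Unpooled;
    import io.netty.channel.Channel;
    import io.netty.channel.ChannelFuture;
    import io.netty.channel.ChannelFutureListener;

    import java.nio.channels.ClosedChannelException;

    // Illustrative caller-side handling; class and method names are hypothetical.
    final class ClosedPeerWriteExample {
        static void writeAndObserve(Channel channel) {
            channel.writeAndFlush(Unpooled.wrappedBuffer(new byte[8]))
                   .addListener(new ChannelFutureListener() {
                       @Override
                       public void operationComplete(ChannelFuture future) {
                           // Expected outcome when the peer closed concurrently: the promise is
                           // failed with ClosedChannelException and the transport has already
                           // released the outbound buffer, so nothing leaks on the caller's side.
                           boolean failedBecausePeerClosed =
                                   !future.isSuccess() && future.cause() instanceof ClosedChannelException;
                           System.out.println("write failed due to closed peer: " + failedBecausePeerClosed);
                       }
                   });
        }
    }

The new testWriteWhilePeerIsClosedReleaseObjectAndFailPromise test in the diff below exercises exactly this path end to end.
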
transport/src/test/java/io/netty/channel/local/LocalChannelTest.java

Lines changed: 132 additions & 7 deletions
@@ -339,8 +339,8 @@ public void testCloseInWritePromiseCompletePreservesOrder() throws InterruptedEx
             @Override
             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                 if (msg.equals(data)) {
-                    messageLatch.countDown();
                     ReferenceCountUtil.safeRelease(msg);
+                    messageLatch.countDown();
                 } else {
                     super.channelRead(ctx, msg);
                 }
@@ -408,8 +408,8 @@ public void testWriteInWritePromiseCompletePreservesOrder() throws InterruptedEx
             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                 final long count = messageLatch.getCount();
                 if ((data.equals(msg) && count == 2) || (data2.equals(msg) && count == 1)) {
-                    messageLatch.countDown();
                     ReferenceCountUtil.safeRelease(msg);
+                    messageLatch.countDown();
                 } else {
                     super.channelRead(ctx, msg);
                 }
@@ -468,8 +468,8 @@ public void testPeerWriteInWritePromiseCompleteDifferentEventLoopPreservesOrder(
             @Override
             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                 if (data2.equals(msg)) {
-                    messageLatch.countDown();
                     ReferenceCountUtil.safeRelease(msg);
+                    messageLatch.countDown();
                 } else {
                     super.channelRead(ctx, msg);
                 }
@@ -485,8 +485,8 @@ public void initChannel(LocalChannel ch) throws Exception {
             @Override
             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                 if (data.equals(msg)) {
-                    messageLatch.countDown();
                     ReferenceCountUtil.safeRelease(msg);
+                    messageLatch.countDown();
                 } else {
                     super.channelRead(ctx, msg);
                 }
@@ -550,8 +550,8 @@ public void testPeerWriteInWritePromiseCompleteSameEventLoopPreservesOrder() thr
             @Override
             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                 if (data2.equals(msg) && messageLatch.getCount() == 1) {
-                    messageLatch.countDown();
                     ReferenceCountUtil.safeRelease(msg);
+                    messageLatch.countDown();
                 } else {
                     super.channelRead(ctx, msg);
                 }
@@ -567,8 +567,8 @@ public void initChannel(LocalChannel ch) throws Exception {
             @Override
             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                 if (data.equals(msg) && messageLatch.getCount() == 2) {
-                    messageLatch.countDown();
                     ReferenceCountUtil.safeRelease(msg);
+                    messageLatch.countDown();
                 } else {
                     super.channelRead(ctx, msg);
                 }
@@ -641,8 +641,8 @@ public void initChannel(LocalChannel ch) throws Exception {
             @Override
             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                 if (msg.equals(data)) {
-                    messageLatch.countDown();
                     ReferenceCountUtil.safeRelease(msg);
+                    messageLatch.countDown();
                 } else {
                     super.channelRead(ctx, msg);
                 }
@@ -697,6 +697,130 @@ public void operationComplete(ChannelFuture future) throws Exception {
         }
     }

+    @Test
+    public void testWriteWhilePeerIsClosedReleaseObjectAndFailPromise() throws InterruptedException {
+        Bootstrap cb = new Bootstrap();
+        ServerBootstrap sb = new ServerBootstrap();
+        final CountDownLatch serverMessageLatch = new CountDownLatch(1);
+        final LatchChannelFutureListener serverChannelCloseLatch = new LatchChannelFutureListener(1);
+        final LatchChannelFutureListener clientChannelCloseLatch = new LatchChannelFutureListener(1);
+        final CountDownLatch writeFailLatch = new CountDownLatch(1);
+        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);
+        final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]);
+        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+        final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
+
+        try {
+            cb.group(group1)
+              .channel(LocalChannel.class)
+              .handler(new TestHandler());
+
+            sb.group(group2)
+              .channel(LocalServerChannel.class)
+              .childHandler(new ChannelInitializer<LocalChannel>() {
+                  @Override
+                  public void initChannel(LocalChannel ch) throws Exception {
+                      ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+                          @Override
+                          public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+                              if (data.equals(msg)) {
+                                  ReferenceCountUtil.safeRelease(msg);
+                                  serverMessageLatch.countDown();
+                              } else {
+                                  super.channelRead(ctx, msg);
+                              }
+                          }
+                      });
+                      serverChannelRef.set(ch);
+                      serverChannelLatch.countDown();
+                  }
+              });
+
+            Channel sc = null;
+            Channel cc = null;
+            try {
+                // Start server
+                sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel();
+
+                // Connect to the server
+                cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel();
+                assertTrue(serverChannelLatch.await(5, SECONDS));
+
+                final Channel ccCpy = cc;
+                final Channel serverChannelCpy = serverChannelRef.get();
+                serverChannelCpy.closeFuture().addListener(serverChannelCloseLatch);
+                ccCpy.closeFuture().addListener(clientChannelCloseLatch);
+
+                // Make sure a write operation is executed in the eventloop
+                cc.pipeline().lastContext().executor().execute(new OneTimeTask() {
+                    @Override
+                    public void run() {
+                        ccCpy.writeAndFlush(data.duplicate().retain(), ccCpy.newPromise())
+                             .addListener(new ChannelFutureListener() {
+                                 @Override
+                                 public void operationComplete(ChannelFuture future) throws Exception {
+                                     serverChannelCpy.eventLoop().execute(new OneTimeTask() {
+                                         @Override
+                                         public void run() {
+                                             // The point of this test is to write while the peer is closed, so we should
+                                             // ensure the peer is actually closed before we write.
+                                             int waitCount = 0;
+                                             while (ccCpy.isOpen()) {
+                                                 try {
+                                                     Thread.sleep(50);
+                                                 } catch (InterruptedException ignored) {
+                                                     // ignored
+                                                 }
+                                                 if (++waitCount > 5) {
+                                                     fail();
+                                                 }
+                                             }
+                                             serverChannelCpy.writeAndFlush(data2.duplicate().retain(),
+                                                                            serverChannelCpy.newPromise())
+                                                 .addListener(new ChannelFutureListener() {
+                                                     @Override
+                                                     public void operationComplete(ChannelFuture future) throws Exception {
+                                                         if (!future.isSuccess() &&
+                                                             future.cause() instanceof ClosedChannelException) {
+                                                             writeFailLatch.countDown();
+                                                         }
+                                                     }
+                                                 });
+                                         }
+                                     });
+                                     ccCpy.close();
+                                 }
+                             });
+                    }
+                });
+
+                assertTrue(serverMessageLatch.await(5, SECONDS));
+                assertTrue(writeFailLatch.await(5, SECONDS));
+                assertTrue(serverChannelCloseLatch.await(5, SECONDS));
+                assertTrue(clientChannelCloseLatch.await(5, SECONDS));
+                assertFalse(ccCpy.isOpen());
+                assertFalse(serverChannelCpy.isOpen());
+            } finally {
+                closeChannel(cc);
+                closeChannel(sc);
+            }
+        } finally {
+            data.release();
+            data2.release();
+        }
+    }
+
+    private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener {
+        public LatchChannelFutureListener(int count) {
+            super(count);
+        }
+
+        @Override
+        public void operationComplete(ChannelFuture future) throws Exception {
+            countDown();
+        }
+    }
+
     private static void closeChannel(Channel cc) {
         if (cc != null) {
             cc.close().syncUninterruptibly();
@@ -707,6 +831,7 @@ static class TestHandler extends ChannelInboundHandlerAdapter {
         @Override
         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
             logger.info(String.format("Received mesage: %s", msg));
+            ReferenceCountUtil.safeRelease(msg);
         }
     }
 }
