Skip to content

Add TLS support #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/main/java/com/rabbitmq/stream/impl/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.flush.FlushConsolidationHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
Expand Down Expand Up @@ -127,6 +128,7 @@
public class Client implements AutoCloseable {

public static final int DEFAULT_PORT = 5552;
public static final int DEFAULT_TLS_PORT = 5551;
static final OutboundEntityWriteCallback OUTBOUND_MESSAGE_WRITE_CALLBACK =
new OutboundMessageWriteCallback();
static final OutboundEntityWriteCallback OUTBOUND_MESSAGE_BATCH_WRITE_CALLBACK =
Expand Down Expand Up @@ -250,6 +252,13 @@ public void initChannel(SocketChannel ch) {
NETTY_HANDLER_FRAME_DECODER,
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
ch.pipeline().addLast(NETTY_HANDLER_STREAM, new StreamHandler());
if (parameters.sslContext != null) {
ch.pipeline()
.addFirst(
"ssl",
parameters.sslContext.newHandler(
ch.alloc(), parameters.host, parameters.port));
}
channelCustomizer.customize(ch);
}
});
Expand Down Expand Up @@ -1683,6 +1692,7 @@ public static class ClientParameters {
private ChannelCustomizer channelCustomizer = ch -> {};
private ChunkChecksum chunkChecksum = JdkChunkChecksum.CRC32_SINGLETON;
private MetricsCollector metricsCollector = NoOpMetricsCollector.SINGLETON;
private SslContext sslContext;

public ClientParameters host(String host) {
this.host = host;
Expand Down Expand Up @@ -1813,6 +1823,14 @@ public ClientParameters shutdownListener(ShutdownListener shutdownListener) {
return this;
}

public ClientParameters sslContext(SslContext sslContext) {
this.sslContext = sslContext;
if (this.port == DEFAULT_PORT && sslContext != null) {
this.port = DEFAULT_TLS_PORT;
}
return this;
}

ClientParameters duplicate() {
ClientParameters duplicate = new ClientParameters();
for (Field field : ClientParameters.class.getDeclaredFields()) {
Expand Down
70 changes: 52 additions & 18 deletions src/test/java/com/rabbitmq/stream/impl/OffsetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.ConnectionFactory;
import com.rabbitmq.stream.OffsetSpecification;
import com.rabbitmq.stream.impl.TestUtils.AlwaysTrustTrustManager;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -37,9 +42,11 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.net.ssl.SSLException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

@ExtendWith(TestUtils.StreamTestInfrastructureExtension.class)
Expand All @@ -60,10 +67,27 @@ static Map<String, String> subscriptionMap() {
return map;
}

static Stream<Arguments> offsetArguments() throws SSLException {
return sslContexts()
.flatMap(
sslContext -> subscriptionProperties().map(props -> Arguments.of(props, sslContext)));
}
;

static Stream<SslContext> sslContexts() throws SSLException {
List<SslContext> contexts = new ArrayList<>();
contexts.add(null);
if (TestUtils.tlsAvailable()) {
contexts.add(
SslContextBuilder.forClient().trustManager(new AlwaysTrustTrustManager()).build());
}
return contexts.stream();
}

@ParameterizedTest
@MethodSource("subscriptionProperties")
void offsetTypeFirstShouldStartConsumingFromBeginning(Map<String, String> subscriptionProperties)
throws Exception {
@MethodSource("offsetArguments")
void offsetTypeFirstShouldStartConsumingFromBeginning(
Map<String, String> subscriptionProperties, SslContext sslContext) throws Exception {
int messageCount = 50000;
TestUtils.publishAndWaitForConfirms(cf, messageCount, stream);
CountDownLatch latch = new CountDownLatch(messageCount);
Expand All @@ -72,6 +96,7 @@ void offsetTypeFirstShouldStartConsumingFromBeginning(Map<String, String> subscr
Client client =
cf.get(
new Client.ClientParameters()
.sslContext(sslContext)
.chunkListener(
(client1, subscriptionId, offset12, messageCount1, dataSize) ->
client1.credit(subscriptionId, 1))
Expand Down Expand Up @@ -117,9 +142,9 @@ void amqpOffsetTypeFirstShouldStartConsumingFromBeginning() throws Exception {
}

@ParameterizedTest
@MethodSource("subscriptionProperties")
void offsetTypeLastShouldReturnLastChunk(Map<String, String> subscriptionProperties)
throws Exception {
@MethodSource("offsetArguments")
void offsetTypeLastShouldReturnLastChunk(
Map<String, String> subscriptionProperties, SslContext sslContext) throws Exception {
int messageCount = 50000;
long lastOffset = messageCount - 1;
TestUtils.publishAndWaitForConfirms(cf, messageCount, stream);
Expand All @@ -131,6 +156,7 @@ void offsetTypeLastShouldReturnLastChunk(Map<String, String> subscriptionPropert
Client client =
cf.get(
new Client.ClientParameters()
.sslContext(sslContext)
.chunkListener(
(client1, subscriptionId, offset12, messageCount1, dataSize) -> {
client1.credit(subscriptionId, 1);
Expand Down Expand Up @@ -198,9 +224,9 @@ void amqpOffsetTypeLastShouldReturnLastChunk() throws Exception {
}

@ParameterizedTest
@MethodSource("subscriptionProperties")
void offsetTypeNextShouldReturnNewPublishedMessages(Map<String, String> subscriptionProperties)
throws Exception {
@MethodSource("offsetArguments")
void offsetTypeNextShouldReturnNewPublishedMessages(
Map<String, String> subscriptionProperties, SslContext sslContext) throws Exception {
int firstWaveMessageCount = 50000;
int secondWaveMessageCount = 20000;
int lastOffset = firstWaveMessageCount + secondWaveMessageCount - 1;
Expand All @@ -211,6 +237,7 @@ void offsetTypeNextShouldReturnNewPublishedMessages(Map<String, String> subscrip
Client client =
cf.get(
new Client.ClientParameters()
.sslContext(sslContext)
.chunkListener(
(client1, subscriptionId, offset, messageCount1, dataSize) ->
client1.credit(subscriptionId, 1))
Expand Down Expand Up @@ -266,9 +293,9 @@ void amqpOffsetTypeNextShouldReturnNewPublishedMessages() throws Exception {
}

@ParameterizedTest
@MethodSource("subscriptionProperties")
void offsetTypeOffsetShouldStartConsumingFromOffset(Map<String, String> subscriptionProperties)
throws Exception {
@MethodSource("offsetArguments")
void offsetTypeOffsetShouldStartConsumingFromOffset(
Map<String, String> subscriptionProperties, SslContext sslContext) throws Exception {
int messageCount = 50000;
TestUtils.publishAndWaitForConfirms(cf, messageCount, stream);
int offset = messageCount / 10;
Expand All @@ -278,6 +305,7 @@ void offsetTypeOffsetShouldStartConsumingFromOffset(Map<String, String> subscrip
Client client =
cf.get(
new Client.ClientParameters()
.sslContext(sslContext)
.chunkListener(
(client1, subscriptionId, offset12, messageCount1, dataSize) ->
client1.credit(subscriptionId, 1))
Expand Down Expand Up @@ -324,9 +352,9 @@ void amqpOffsetTypeOffsetShouldStartConsumingFromOffset() throws Exception {
}

@ParameterizedTest
@MethodSource("subscriptionProperties")
@MethodSource("offsetArguments")
void offsetTypeTimestampShouldStartConsumingFromTimestamp(
Map<String, String> subscriptionProperties) throws Exception {
Map<String, String> subscriptionProperties, SslContext sslContext) throws Exception {
int firstWaveMessageCount = 50000;
int secondWaveMessageCount = 20000;
int lastOffset = firstWaveMessageCount + secondWaveMessageCount - 1;
Expand All @@ -342,6 +370,7 @@ void offsetTypeTimestampShouldStartConsumingFromTimestamp(
Client client =
cf.get(
new Client.ClientParameters()
.sslContext(sslContext)
.chunkListener(
(client1, subscriptionId, offset, messageCount1, dataSize) ->
client1.credit(subscriptionId, 1))
Expand Down Expand Up @@ -442,8 +471,9 @@ void filterSmallerOffsets() throws Exception {
}
}

@Test
void consumeFromTail() throws Exception {
@ParameterizedTest
@MethodSource("sslContexts")
void consumeFromTail(SslContext sslContext) throws Exception {
int messageCount = 10000;
CountDownLatch firstWaveLatch = new CountDownLatch(messageCount);
CountDownLatch secondWaveLatch = new CountDownLatch(messageCount * 2);
Expand Down Expand Up @@ -474,6 +504,7 @@ void consumeFromTail() throws Exception {
Client consumer =
cf.get(
new Client.ClientParameters()
.sslContext(sslContext)
.chunkListener(
(client, subscriptionId, offset, messageCount1, dataSize) ->
client.credit(subscriptionId, 1))
Expand Down Expand Up @@ -502,8 +533,10 @@ void consumeFromTail() throws Exception {
.forEach(v -> assertThat(v).startsWith("second wave").doesNotStartWith("first wave"));
}

@Test
void shouldReachTailWhenPublisherStopWhileConsumerIsBehind() throws Exception {
@ParameterizedTest
@MethodSource("sslContexts")
void shouldReachTailWhenPublisherStopWhileConsumerIsBehind(SslContext sslContext)
throws Exception {
int messageCount = 100000;
int messageLimit = messageCount * 2;
AtomicLong lastConfirmed = new AtomicLong();
Expand All @@ -524,6 +557,7 @@ void shouldReachTailWhenPublisherStopWhileConsumerIsBehind() throws Exception {
Client consumer =
cf.get(
new Client.ClientParameters()
.sslContext(sslContext)
.chunkListener(
(client, subscriptionId, offset, msgCount, dataSize) -> client.credit(b(0), 1))
.messageListener(
Expand Down
Loading