diff --git a/src/main/java/com/rabbitmq/stream/impl/Client.java b/src/main/java/com/rabbitmq/stream/impl/Client.java index 47b691fa99..3864c199d2 100644 --- a/src/main/java/com/rabbitmq/stream/impl/Client.java +++ b/src/main/java/com/rabbitmq/stream/impl/Client.java @@ -82,6 +82,7 @@ import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.flush.FlushConsolidationHandler; +import io.netty.handler.ssl.SslContext; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import io.netty.handler.timeout.IdleStateHandler; @@ -127,6 +128,7 @@ public class Client implements AutoCloseable { public static final int DEFAULT_PORT = 5552; + public static final int DEFAULT_TLS_PORT = 5551; static final OutboundEntityWriteCallback OUTBOUND_MESSAGE_WRITE_CALLBACK = new OutboundMessageWriteCallback(); static final OutboundEntityWriteCallback OUTBOUND_MESSAGE_BATCH_WRITE_CALLBACK = @@ -250,6 +252,13 @@ public void initChannel(SocketChannel ch) { NETTY_HANDLER_FRAME_DECODER, new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)); ch.pipeline().addLast(NETTY_HANDLER_STREAM, new StreamHandler()); + if (parameters.sslContext != null) { + ch.pipeline() + .addFirst( + "ssl", + parameters.sslContext.newHandler( + ch.alloc(), parameters.host, parameters.port)); + } channelCustomizer.customize(ch); } }); @@ -1683,6 +1692,7 @@ public static class ClientParameters { private ChannelCustomizer channelCustomizer = ch -> {}; private ChunkChecksum chunkChecksum = JdkChunkChecksum.CRC32_SINGLETON; private MetricsCollector metricsCollector = NoOpMetricsCollector.SINGLETON; + private SslContext sslContext; public ClientParameters host(String host) { this.host = host; @@ -1813,6 +1823,14 @@ public ClientParameters shutdownListener(ShutdownListener shutdownListener) { return this; } + public ClientParameters sslContext(SslContext sslContext) { + this.sslContext = sslContext; + if (this.port == DEFAULT_PORT && sslContext != null) { + this.port = DEFAULT_TLS_PORT; + } + return this; + } + ClientParameters duplicate() { ClientParameters duplicate = new ClientParameters(); for (Field field : ClientParameters.class.getDeclaredFields()) { diff --git a/src/test/java/com/rabbitmq/stream/impl/OffsetTest.java b/src/test/java/com/rabbitmq/stream/impl/OffsetTest.java index 9943048a3b..55164a2b45 100644 --- a/src/test/java/com/rabbitmq/stream/impl/OffsetTest.java +++ b/src/test/java/com/rabbitmq/stream/impl/OffsetTest.java @@ -22,11 +22,16 @@ import com.rabbitmq.client.Connection; import com.rabbitmq.client.ConnectionFactory; import com.rabbitmq.stream.OffsetSpecification; +import com.rabbitmq.stream.impl.TestUtils.AlwaysTrustTrustManager; import io.netty.channel.EventLoopGroup; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Collections; import java.util.Date; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -37,9 +42,11 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; import java.util.stream.Stream; +import javax.net.ssl.SSLException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @ExtendWith(TestUtils.StreamTestInfrastructureExtension.class) @@ -60,10 +67,27 @@ static Map subscriptionMap() { return map; } + static Stream offsetArguments() throws SSLException { + return sslContexts() + .flatMap( + sslContext -> subscriptionProperties().map(props -> Arguments.of(props, sslContext))); + } + ; + + static Stream sslContexts() throws SSLException { + List contexts = new ArrayList<>(); + contexts.add(null); + if (TestUtils.tlsAvailable()) { + contexts.add( + SslContextBuilder.forClient().trustManager(new AlwaysTrustTrustManager()).build()); + } + return contexts.stream(); + } + @ParameterizedTest - @MethodSource("subscriptionProperties") - void offsetTypeFirstShouldStartConsumingFromBeginning(Map subscriptionProperties) - throws Exception { + @MethodSource("offsetArguments") + void offsetTypeFirstShouldStartConsumingFromBeginning( + Map subscriptionProperties, SslContext sslContext) throws Exception { int messageCount = 50000; TestUtils.publishAndWaitForConfirms(cf, messageCount, stream); CountDownLatch latch = new CountDownLatch(messageCount); @@ -72,6 +96,7 @@ void offsetTypeFirstShouldStartConsumingFromBeginning(Map subscr Client client = cf.get( new Client.ClientParameters() + .sslContext(sslContext) .chunkListener( (client1, subscriptionId, offset12, messageCount1, dataSize) -> client1.credit(subscriptionId, 1)) @@ -117,9 +142,9 @@ void amqpOffsetTypeFirstShouldStartConsumingFromBeginning() throws Exception { } @ParameterizedTest - @MethodSource("subscriptionProperties") - void offsetTypeLastShouldReturnLastChunk(Map subscriptionProperties) - throws Exception { + @MethodSource("offsetArguments") + void offsetTypeLastShouldReturnLastChunk( + Map subscriptionProperties, SslContext sslContext) throws Exception { int messageCount = 50000; long lastOffset = messageCount - 1; TestUtils.publishAndWaitForConfirms(cf, messageCount, stream); @@ -131,6 +156,7 @@ void offsetTypeLastShouldReturnLastChunk(Map subscriptionPropert Client client = cf.get( new Client.ClientParameters() + .sslContext(sslContext) .chunkListener( (client1, subscriptionId, offset12, messageCount1, dataSize) -> { client1.credit(subscriptionId, 1); @@ -198,9 +224,9 @@ void amqpOffsetTypeLastShouldReturnLastChunk() throws Exception { } @ParameterizedTest - @MethodSource("subscriptionProperties") - void offsetTypeNextShouldReturnNewPublishedMessages(Map subscriptionProperties) - throws Exception { + @MethodSource("offsetArguments") + void offsetTypeNextShouldReturnNewPublishedMessages( + Map subscriptionProperties, SslContext sslContext) throws Exception { int firstWaveMessageCount = 50000; int secondWaveMessageCount = 20000; int lastOffset = firstWaveMessageCount + secondWaveMessageCount - 1; @@ -211,6 +237,7 @@ void offsetTypeNextShouldReturnNewPublishedMessages(Map subscrip Client client = cf.get( new Client.ClientParameters() + .sslContext(sslContext) .chunkListener( (client1, subscriptionId, offset, messageCount1, dataSize) -> client1.credit(subscriptionId, 1)) @@ -266,9 +293,9 @@ void amqpOffsetTypeNextShouldReturnNewPublishedMessages() throws Exception { } @ParameterizedTest - @MethodSource("subscriptionProperties") - void offsetTypeOffsetShouldStartConsumingFromOffset(Map subscriptionProperties) - throws Exception { + @MethodSource("offsetArguments") + void offsetTypeOffsetShouldStartConsumingFromOffset( + Map subscriptionProperties, SslContext sslContext) throws Exception { int messageCount = 50000; TestUtils.publishAndWaitForConfirms(cf, messageCount, stream); int offset = messageCount / 10; @@ -278,6 +305,7 @@ void offsetTypeOffsetShouldStartConsumingFromOffset(Map subscrip Client client = cf.get( new Client.ClientParameters() + .sslContext(sslContext) .chunkListener( (client1, subscriptionId, offset12, messageCount1, dataSize) -> client1.credit(subscriptionId, 1)) @@ -324,9 +352,9 @@ void amqpOffsetTypeOffsetShouldStartConsumingFromOffset() throws Exception { } @ParameterizedTest - @MethodSource("subscriptionProperties") + @MethodSource("offsetArguments") void offsetTypeTimestampShouldStartConsumingFromTimestamp( - Map subscriptionProperties) throws Exception { + Map subscriptionProperties, SslContext sslContext) throws Exception { int firstWaveMessageCount = 50000; int secondWaveMessageCount = 20000; int lastOffset = firstWaveMessageCount + secondWaveMessageCount - 1; @@ -342,6 +370,7 @@ void offsetTypeTimestampShouldStartConsumingFromTimestamp( Client client = cf.get( new Client.ClientParameters() + .sslContext(sslContext) .chunkListener( (client1, subscriptionId, offset, messageCount1, dataSize) -> client1.credit(subscriptionId, 1)) @@ -442,8 +471,9 @@ void filterSmallerOffsets() throws Exception { } } - @Test - void consumeFromTail() throws Exception { + @ParameterizedTest + @MethodSource("sslContexts") + void consumeFromTail(SslContext sslContext) throws Exception { int messageCount = 10000; CountDownLatch firstWaveLatch = new CountDownLatch(messageCount); CountDownLatch secondWaveLatch = new CountDownLatch(messageCount * 2); @@ -474,6 +504,7 @@ void consumeFromTail() throws Exception { Client consumer = cf.get( new Client.ClientParameters() + .sslContext(sslContext) .chunkListener( (client, subscriptionId, offset, messageCount1, dataSize) -> client.credit(subscriptionId, 1)) @@ -502,8 +533,10 @@ void consumeFromTail() throws Exception { .forEach(v -> assertThat(v).startsWith("second wave").doesNotStartWith("first wave")); } - @Test - void shouldReachTailWhenPublisherStopWhileConsumerIsBehind() throws Exception { + @ParameterizedTest + @MethodSource("sslContexts") + void shouldReachTailWhenPublisherStopWhileConsumerIsBehind(SslContext sslContext) + throws Exception { int messageCount = 100000; int messageLimit = messageCount * 2; AtomicLong lastConfirmed = new AtomicLong(); @@ -524,6 +557,7 @@ void shouldReachTailWhenPublisherStopWhileConsumerIsBehind() throws Exception { Client consumer = cf.get( new Client.ClientParameters() + .sslContext(sslContext) .chunkListener( (client, subscriptionId, offset, msgCount, dataSize) -> client.credit(b(0), 1)) .messageListener( diff --git a/src/test/java/com/rabbitmq/stream/impl/TestUtils.java b/src/test/java/com/rabbitmq/stream/impl/TestUtils.java index 840dfaa4aa..ce5834b49a 100644 --- a/src/test/java/com/rabbitmq/stream/impl/TestUtils.java +++ b/src/test/java/com/rabbitmq/stream/impl/TestUtils.java @@ -35,10 +35,15 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.lang.annotation.*; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; +import java.security.cert.X509Certificate; import java.time.Duration; import java.util.Collection; import java.util.Collections; @@ -56,10 +61,18 @@ import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.IntStream; +import javax.net.ssl.X509TrustManager; import org.assertj.core.api.AssertDelegateTarget; import org.assertj.core.api.Condition; import org.junit.jupiter.api.TestInfo; -import org.junit.jupiter.api.extension.*; +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ConditionEvaluationResult; +import org.junit.jupiter.api.extension.ExecutionCondition; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.ExtensionContext; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -200,6 +213,132 @@ static Answer answer(Consumer invocation) { }; } + static void doIfNotNull(T obj, Consumer action) { + if (obj != null) { + action.accept(obj); + } + } + + static void declareSuperStreamTopology(Connection connection, String superStream, int partitions) + throws Exception { + declareSuperStreamTopology( + connection, + superStream, + IntStream.range(0, partitions).mapToObj(String::valueOf).toArray(String[]::new)); + } + + static void declareSuperStreamTopology( + Connection connection, String superStream, String... routingKeys) throws Exception { + try (Channel ch = connection.createChannel()) { + ch.exchangeDeclare(superStream, BuiltinExchangeType.DIRECT, true); + for (String routingKey : routingKeys) { + String partitionName = superStream + "-" + routingKey; + ch.queueDeclare( + partitionName, true, false, false, Collections.singletonMap("x-queue-type", "stream")); + // TODO consider adding some arguments to the bindings + // can be useful to identify a partition, e.g. partition number + ch.queueBind(partitionName, superStream, routingKey); + } + } + } + + static void deleteSuperStreamTopology(Connection connection, String superStream, int partitions) + throws Exception { + deleteSuperStreamTopology( + connection, + superStream, + IntStream.range(0, partitions).mapToObj(String::valueOf).toArray(String[]::new)); + } + + static void deleteSuperStreamTopology( + Connection connection, String superStream, String... routingKeys) throws Exception { + try (Channel ch = connection.createChannel()) { + ch.exchangeDelete(superStream); + for (String routingKey : routingKeys) { + String partitionName = superStream + "-" + routingKey; + ch.queueDelete(partitionName); + } + } + } + + public static String streamName(TestInfo info) { + return streamName(info.getTestClass().get(), info.getTestMethod().get()); + } + + private static String streamName(ExtensionContext context) { + return streamName(context.getTestInstance().get().getClass(), context.getTestMethod().get()); + } + + private static String streamName(Class testClass, Method testMethod) { + String uuid = UUID.randomUUID().toString(); + return String.format( + "%s_%s%s", + testClass.getSimpleName(), testMethod.getName(), uuid.substring(uuid.length() / 2)); + } + + static boolean tlsAvailable() { + if (Host.rabbitmqctlCommand() == null) { + throw new IllegalStateException( + "rabbitmqctl.bin system property not set, cannot check if MQTT plugin is enabled"); + } else { + try { + Process process = Host.rabbitmqctl("status"); + String output = capture(process.getInputStream()); + return output.contains("stream/ssl"); + } catch (Exception e) { + throw new RuntimeException("Error while trying to detect TLS: " + e.getMessage()); + } + } + } + + private static String capture(InputStream is) throws IOException { + BufferedReader br = new BufferedReader(new InputStreamReader(is)); + String line; + StringBuilder buff = new StringBuilder(); + while ((line = br.readLine()) != null) { + buff.append(line).append("\n"); + } + return buff.toString(); + } + + static void forEach(Collection in, CallableIndexConsumer consumer) throws Exception { + int count = 0; + for (T t : in) { + consumer.accept(count++, t); + } + } + + static CountDownLatchAssert latchAssert(CountDownLatch latch) { + return new CountDownLatchAssert(latch); + } + + static CountDownLatchAssert latchAssert(AtomicReference latchReference) { + return new CountDownLatchAssert(latchReference.get()); + } + + static Condition responseCode(short expectedResponseCode) { + String message = "expected code for stream exception is " + expectedResponseCode; + return new Condition<>( + throwable -> + throwable instanceof StreamException + && ((StreamException) throwable).getCode() == expectedResponseCode, + message); + } + + static Map metadata(String stream, Broker leader, List replicas) { + return metadata(stream, leader, replicas, Constants.RESPONSE_CODE_OK); + } + + static Map metadata( + String stream, Broker leader, List replicas, short code) { + return Collections.singletonMap( + stream, new Client.StreamMetadata(stream, code, leader, replicas)); + } + + static Map metadata(Broker leader, List replicas) { + return metadata("stream", leader, replicas); + } + @Target({ElementType.TYPE, ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) @Documented @@ -218,11 +357,27 @@ static Answer answer(Consumer invocation) { @ExtendWith(DisabledIfStompNotEnabledCondition.class) @interface DisabledIfStompNotEnabled {} + @Target({ElementType.TYPE, ElementType.METHOD}) + @Retention(RetentionPolicy.RUNTIME) + @Documented + @ExtendWith(DisabledIfTlsNotEnabledCondition.class) + @interface DisabledIfTlsNotEnabled {} + interface TaskWithException { void run(Object context) throws Exception; } + interface CallableIndexConsumer { + + void accept(int index, T t) throws Exception; + } + + interface CallableConsumer { + + void accept(T t) throws Exception; + } + public static class StreamTestInfrastructureExtension implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback { @@ -307,69 +462,6 @@ public void afterAll(ExtensionContext context) throws Exception { } } - static void doIfNotNull(T obj, Consumer action) { - if (obj != null) { - action.accept(obj); - } - } - - static void declareSuperStreamTopology(Connection connection, String superStream, int partitions) - throws Exception { - declareSuperStreamTopology( - connection, - superStream, - IntStream.range(0, partitions).mapToObj(String::valueOf).toArray(String[]::new)); - } - - static void declareSuperStreamTopology( - Connection connection, String superStream, String... routingKeys) throws Exception { - try (Channel ch = connection.createChannel()) { - ch.exchangeDeclare(superStream, BuiltinExchangeType.DIRECT, true); - for (String routingKey : routingKeys) { - String partitionName = superStream + "-" + routingKey; - ch.queueDeclare( - partitionName, true, false, false, Collections.singletonMap("x-queue-type", "stream")); - // TODO consider adding some arguments to the bindings - // can be useful to identify a partition, e.g. partition number - ch.queueBind(partitionName, superStream, routingKey); - } - } - } - - static void deleteSuperStreamTopology(Connection connection, String superStream, int partitions) - throws Exception { - deleteSuperStreamTopology( - connection, - superStream, - IntStream.range(0, partitions).mapToObj(String::valueOf).toArray(String[]::new)); - } - - static void deleteSuperStreamTopology( - Connection connection, String superStream, String... routingKeys) throws Exception { - try (Channel ch = connection.createChannel()) { - ch.exchangeDelete(superStream); - for (String routingKey : routingKeys) { - String partitionName = superStream + "-" + routingKey; - ch.queueDelete(partitionName); - } - } - } - - public static String streamName(TestInfo info) { - return streamName(info.getTestClass().get(), info.getTestMethod().get()); - } - - private static String streamName(ExtensionContext context) { - return streamName(context.getTestInstance().get().getClass(), context.getTestMethod().get()); - } - - private static String streamName(Class testClass, Method testMethod) { - String uuid = UUID.randomUUID().toString(); - return String.format( - "%s_%s%s", - testClass.getSimpleName(), testMethod.getName(), uuid.substring(uuid.length() / 2)); - } - public static class ClientFactory { private final EventLoopGroup eventLoopGroup; @@ -456,33 +548,18 @@ public ConditionEvaluationResult evaluateExecutionCondition(ExtensionContext con } } - private static String capture(InputStream is) throws IOException { - BufferedReader br = new BufferedReader(new InputStreamReader(is)); - String line; - StringBuilder buff = new StringBuilder(); - while ((line = br.readLine()) != null) { - buff.append(line).append("\n"); - } - return buff.toString(); - } + static class DisabledIfTlsNotEnabledCondition implements ExecutionCondition { - static void forEach(Collection in, CallableIndexConsumer consumer) throws Exception { - int count = 0; - for (T t : in) { - consumer.accept(count++, t); + @Override + public ConditionEvaluationResult evaluateExecutionCondition(ExtensionContext context) { + if (tlsAvailable()) { + return ConditionEvaluationResult.enabled("TLS is enabled"); + } else { + return ConditionEvaluationResult.disabled("TLS is disabled"); + } } } - interface CallableIndexConsumer { - - void accept(int index, T t) throws Exception; - } - - interface CallableConsumer { - - void accept(T t) throws Exception; - } - static class CountDownLatchAssert implements AssertDelegateTarget { private static final Duration TIMEOUT = Duration.ofSeconds(10); @@ -528,34 +605,16 @@ void doesNotComplete(Duration timeout) { } } - static CountDownLatchAssert latchAssert(CountDownLatch latch) { - return new CountDownLatchAssert(latch); - } - - static CountDownLatchAssert latchAssert(AtomicReference latchReference) { - return new CountDownLatchAssert(latchReference.get()); - } - - static Condition responseCode(short expectedResponseCode) { - String message = "expected code for stream exception is " + expectedResponseCode; - return new Condition<>( - throwable -> - throwable instanceof StreamException - && ((StreamException) throwable).getCode() == expectedResponseCode, - message); - } - - static Map metadata(String stream, Broker leader, List replicas) { - return metadata(stream, leader, replicas, Constants.RESPONSE_CODE_OK); - } + static class AlwaysTrustTrustManager implements X509TrustManager { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) {} - static Map metadata( - String stream, Broker leader, List replicas, short code) { - return Collections.singletonMap( - stream, new Client.StreamMetadata(stream, code, leader, replicas)); - } + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) {} - static Map metadata(Broker leader, List replicas) { - return metadata("stream", leader, replicas); + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } } } diff --git a/src/test/java/com/rabbitmq/stream/impl/TlsTest.java b/src/test/java/com/rabbitmq/stream/impl/TlsTest.java new file mode 100644 index 0000000000..c5c6a096fa --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/impl/TlsTest.java @@ -0,0 +1,101 @@ +// Copyright (c) 2021 VMware, Inc. or its affiliates. All rights reserved. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. + +package com.rabbitmq.stream.impl; + +import static com.rabbitmq.stream.impl.TestUtils.b; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +import com.rabbitmq.stream.OffsetSpecification; +import com.rabbitmq.stream.impl.TestUtils.AlwaysTrustTrustManager; +import com.rabbitmq.stream.impl.TestUtils.DisabledIfTlsNotEnabled; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import javax.net.ssl.SSLException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@DisabledIfTlsNotEnabled +@ExtendWith(TestUtils.StreamTestInfrastructureExtension.class) +public class TlsTest { + + String stream; + + TestUtils.ClientFactory cf; + int credit = 10; + + static SslContext sslContext() { + try { + return SslContextBuilder.forClient().trustManager(new AlwaysTrustTrustManager()).build(); + } catch (SSLException e) { + throw new RuntimeException(e); + } + } + + @Test + void publishAndConsume() throws Exception { + int publishCount = 1_000_000; + + CountDownLatch consumedLatch = new CountDownLatch(publishCount); + Client.ChunkListener chunkListener = + (client, correlationId, offset, messageCount, dataSize) -> { + if (consumedLatch.getCount() != 0) { + client.credit(correlationId, 1); + } + }; + + Client.MessageListener messageListener = (corr, offset, data) -> consumedLatch.countDown(); + + Client client = + cf.get( + new Client.ClientParameters() + .sslContext(sslContext()) + .chunkListener(chunkListener) + .messageListener(messageListener)); + + client.subscribe(b(1), stream, OffsetSpecification.first(), credit); + + CountDownLatch confirmedLatch = new CountDownLatch(publishCount); + new Thread( + () -> { + Client publisher = + cf.get( + new Client.ClientParameters() + .sslContext(sslContext()) + .publishConfirmListener( + (publisherId, correlationId) -> confirmedLatch.countDown())); + int messageId = 0; + publisher.declarePublisher(b(1), null, stream); + while (messageId < publishCount) { + messageId++; + publisher.publish( + b(1), + Collections.singletonList( + publisher + .messageBuilder() + .addData(("message" + messageId).getBytes(StandardCharsets.UTF_8)) + .build())); + } + }) + .start(); + + assertThat(confirmedLatch.await(15, SECONDS)).isTrue(); + assertThat(consumedLatch.await(15, SECONDS)).isTrue(); + client.unsubscribe(b(1)); + } +} diff --git a/src/test/resources/logback-test.xml b/src/test/resources/logback-test.xml index ab23c3a378..679e96797c 100644 --- a/src/test/resources/logback-test.xml +++ b/src/test/resources/logback-test.xml @@ -7,7 +7,6 @@ -