diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 72ad9316..2a4b7f90 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -19,6 +19,8 @@ jspecify = "1.0.0" mockito = "5.20.0" # https://mvnrepository.com/artifact/org.assertj/assertj-core assertj-core = "4.0.0-M1" +# https://mvnrepository.com/artifact/org.awaitility/awaitility +awaitility = "4.3.0" [libraries] project-reactor-core = { module = "io.projectreactor:reactor-core", version.ref = "project-reactor" } @@ -37,6 +39,7 @@ testcontainers = { module = "org.testcontainers:testcontainers", version.ref = " mockito-core = { module = "org.mockito:mockito-core", version.ref = "mockito" } mockito-junit-jupiter = { module = "org.mockito:mockito-junit-jupiter", version.ref = "mockito" } assertj-core = { module = "org.assertj:assertj-core", version.ref = "assertj-core" } +awaitility = { module = "org.awaitility:awaitility", version.ref = "awaitility" } [bundles] mail = ["jakarta-mail-api", "angus-mail"] diff --git a/rlib-common/src/test/java/javasabr/rlib/common/util/AwaitUtilsTest.java b/rlib-common/src/test/java/javasabr/rlib/common/util/AwaitUtilsTest.java deleted file mode 100644 index 92958d7b..00000000 --- a/rlib-common/src/test/java/javasabr/rlib/common/util/AwaitUtilsTest.java +++ /dev/null @@ -1,48 +0,0 @@ -package javasabr.rlib.common.util; - -import static org.assertj.core.api.Assertions.assertThat; - -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.jupiter.api.Test; - -/** - * Tests of {@link AwaitUtils} methods. - * - * @author crazyrokr - */ -public class AwaitUtilsTest { - - @Test - void shouldAwaitCondition() throws InterruptedException { - // given - var condition = new AtomicBoolean(false); - var thread = new Thread(() -> { - try { - Thread.sleep(100); - condition.set(true); - } catch (InterruptedException e) { - // ignore - } - }); - - // when - thread.start(); - boolean result = AwaitUtils.await(500, TimeUnit.MILLISECONDS, condition::get); - - // then - assertThat(result).isTrue(); - } - - @Test - void shouldTimeoutIfConditionNotMet() throws InterruptedException { - // given - var condition = new AtomicBoolean(false); - - // when - boolean result = AwaitUtils.await(100, TimeUnit.MILLISECONDS, condition::get); - - // then - assertThat(result).isFalse(); - } -} diff --git a/rlib-common/src/testFixtures/java/javasabr/rlib/common/util/AwaitUtils.java b/rlib-common/src/testFixtures/java/javasabr/rlib/common/util/AwaitUtils.java deleted file mode 100644 index 869b3d6e..00000000 --- a/rlib-common/src/testFixtures/java/javasabr/rlib/common/util/AwaitUtils.java +++ /dev/null @@ -1,38 +0,0 @@ -package javasabr.rlib.common.util; - -import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; -import lombok.experimental.UtilityClass; - -/** - * The utility class to await some conditions. - * - * @author crazyrokr - */ -@UtilityClass -public final class AwaitUtils { - - /** - * Await for the condition during the amount of time units. - * - * @param amount the amount of time units. - * @param unit the time unit. - * @param condition the condition. - * @return true if the condition was met. - * @throws InterruptedException if the current thread was interrupted. - */ - public static boolean await(long amount, TimeUnit unit, Supplier condition) throws InterruptedException { - if (condition.get()) { - return true; - } - var timeoutMillis = unit.toMillis(amount); - var endTime = System.currentTimeMillis() + timeoutMillis; - while (System.currentTimeMillis() < endTime) { - if (condition.get()) { - return true; - } - Thread.sleep(Math.clamp(endTime - System.currentTimeMillis(), 1, 10)); - } - return condition.get(); - } -} diff --git a/rlib-network/build.gradle b/rlib-network/build.gradle index c8763542..79bd2819 100644 --- a/rlib-network/build.gradle +++ b/rlib-network/build.gradle @@ -12,4 +12,5 @@ dependencies { testRuntimeOnly projects.rlibLoggerImpl loadTestRuntimeOnly projects.rlibLoggerImpl testImplementation testFixtures(projects.rlibCommon) + testImplementation libs.awaitility } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java index 5b7033d8..6154547a 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java @@ -160,6 +160,9 @@ protected int doHandshake(ByteBuffer networkBuffer, int receivedBytes) { case NEED_WRAP: { log.debug(remoteAddress, "[%s] Send command to wrap data"::formatted); packetWriter.accept(SslWrapRequestNetworkPacket.getInstance()); + if (networkBuffer.hasRemaining()) { + return decryptAndRead(networkBuffer); + } NetworkUtils.cleanNetworkBuffer(networkBuffer); return SKIP_READ_PACKETS; } @@ -204,6 +207,11 @@ protected int decryptAndRead(ByteBuffer receivedBuffer) { } switch (result.getStatus()) { case OK: { + if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) { + log.debug(remoteAddress, "[%s] No progress during decryption, skip read packets"::formatted); + NetworkUtils.cleanNetworkBuffer(receivedBuffer); + return SKIP_READ_PACKETS; + } sslDataBuffer.flip(); logDataAfterDecrypt(remoteAddress, sslDataBuffer); total += readPackets(sslDataBuffer, sslDataPendingBuffer); diff --git a/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java b/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java index e831d302..35c21cc1 100644 --- a/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java +++ b/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java @@ -1,13 +1,14 @@ package javasabr.rlib.network; +import static java.util.function.Predicate.isEqual; import static javasabr.rlib.network.util.NetworkUtils.createAllTrustedClientSslContext; import static javasabr.rlib.network.util.NetworkUtils.createSslContext; import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; import java.net.InetSocketAddress; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import javasabr.rlib.common.util.AwaitUtils; import javasabr.rlib.network.exception.ConnectionClosedException; import javasabr.rlib.network.impl.AbstractConnection; import javasabr.rlib.network.impl.DefaultConnection; @@ -80,14 +81,16 @@ void shouldCloseServerConnectionWhenClientClosesTcpChannelAbruptly() { // when clientConnection.channel().close(); - assertThat(AwaitUtils.await(5, TimeUnit.SECONDS, clientConnection::closed)) - .as("Client connection should be closed prior server side verification") - .isTrue(); // then - assertThat(AwaitUtils.await(5, TimeUnit.SECONDS, serverConnection::closed)) - .as("Server connection should be closed after receiving EOF from abruptly closed client channel") - .isTrue(); + await() + .alias("Client connection should be closed prior server side verification") + .atMost(5, TimeUnit.SECONDS) + .until(clientConnection::closed, isEqual(true)); + await() + .alias("Server connection should be closed after receiving EOF from abruptly closed client channel") + .atMost(5, TimeUnit.SECONDS) + .until(serverConnection::closed, isEqual(true)); } } } diff --git a/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReaderTest.java b/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReaderTest.java new file mode 100644 index 00000000..dc48da0b --- /dev/null +++ b/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReaderTest.java @@ -0,0 +1,161 @@ +package javasabr.rlib.network.packet.impl; + +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.function.Consumer; +import javasabr.rlib.network.BufferAllocator; +import javasabr.rlib.network.Network; +import javasabr.rlib.network.NetworkConfig; +import javasabr.rlib.network.UnsafeConnection; +import javasabr.rlib.network.impl.DefaultBufferAllocator; +import javasabr.rlib.network.packet.ReadableNetworkPacket; +import javasabr.rlib.network.packet.WritableNetworkPacket; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLSession; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +/** + * The tests of SSL packet reader + * + * @author crazyrokr + */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +class AbstractSslNetworkPacketReaderTest { + + @Mock + private TestConnection connection; + + @Mock + private Network network; + + @Mock + private SSLEngine sslEngine; + + @Mock + private SSLSession sslSession; + + @Mock + private Consumer> packetHandler; + + @Mock + private Consumer> packetWriter; + + DefaultSslNetworkPacketReader, TestConnection> reader; + + private BufferAllocator bufferAllocator; + + private interface TestConnection extends UnsafeConnection {} + + @BeforeEach + void setUp() { + bufferAllocator = new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT); + when(connection.bufferAllocator()).thenReturn(bufferAllocator); + when(connection.network()).thenReturn((Network) network); + when(connection.remoteAddress()).thenReturn("test-address"); + when(network.config()).thenReturn(NetworkConfig.DEFAULT_CLIENT); + when(sslEngine.getSession()).thenReturn(sslSession); + when(sslSession.getApplicationBufferSize()).thenReturn(1024); + when(sslSession.getPacketBufferSize()).thenReturn(1024); + reader = spy(new DefaultSslNetworkPacketReader, TestConnection>( + connection, + () -> {}, + packetHandler, + packetHandler, + len -> mock(ReadableNetworkPacket.class), + sslEngine, + packetWriter, + 1, + 100) { + }); + doReturn(1).when(reader).readFullPacketLength(any(ByteBuffer.class)); + } + + @Test + void shouldNotLoseDataOnNeedWrapDuringHandshake() throws Exception { + // given + // Initial state: NEED_UNWRAP + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_UNWRAP); + + // First unwrap will result in NEED_WRAP and status OK, consuming some data. + // Simulate a single network buffer containing 5 bytes of handshake data followed by + // 5 bytes of application data, so the remaining bytes can still be processed afterward. + ByteBuffer networkData = ByteBuffer.allocate(10); + networkData.put(new byte[10]); + networkData.flip(); + + // doHandshake calls unwrap in NEED_UNWRAP, consumes first 5 bytes, then returns OK + when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer[].class))).thenAnswer(invocation -> { + ByteBuffer in = invocation.getArgument(0); + in.position(in.position() + 5); // consume 5 bytes of handshake + // Change status to NEED_WRAP for next getHandshakeStatus() call + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP); + return new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 5, 0); + }); + + // decryptAndRead calls unwrap, consumes the remaining 5 bytes, then return FINISHED or NOT_HANDSHAKING + when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class))).thenAnswer(invocation -> { + ByteBuffer in = invocation.getArgument(0); + ByteBuffer out = invocation.getArgument(1); + int remaining = in.remaining(); + in.position(in.limit()); // consume all + out.put(new byte[remaining]); // put decrypted data (mocked) + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NOT_HANDSHAKING); + return new SSLEngineResult(Status.OK, HandshakeStatus.NOT_HANDSHAKING, remaining, remaining); + }); + + // when + reader.readPackets(networkData); + + // then + // readPackets should have been called for the remaining 5 bytes, + // since each packet is 1 byte, it should have read 5 packets + verify(reader, times(5)).createPacketFor(any(ByteBuffer.class), anyInt(), anyInt(), anyInt()); + verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class)); + } + + @Test + void testShouldNotDeadLoopWhenNeedWrapAndNoProgress() throws Exception { + // given + // Initial state: NEED_WRAP + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP); + + // Network buffer has data + ByteBuffer networkData = ByteBuffer.allocate(10); + networkData.put(new byte[10]); + networkData.flip(); + + // Mock unwrap in decryptAndRead to return OK with 0 progress + // This happens if engine is in NEED_WRAP and can't decrypt application data + when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class))) + .thenReturn(new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 0, 0)); + + // when + // We expect this NOT to hang indefinitely. + // If it dead-loops, the test will fail by timeout. + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> reader.readPackets(networkData)); + + // then + // Should have requested wrap + verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class)); + } +}