From 7838321b74276e45b92c54904ea31ef70ed9e33f Mon Sep 17 00:00:00 2001 From: Alan Bateman Date: Wed, 4 Jun 2025 09:52:45 +0000 Subject: [PATCH] 8358496: Concurrent reading from Socket with timeout executes sequentially Reviewed-by: dfuchs --- .../classes/sun/nio/ch/NioSocketImpl.java | 23 +- test/jdk/java/net/Socket/Timeouts.java | 246 +++++++++++------- 2 files changed, 163 insertions(+), 106 deletions(-) diff --git a/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java b/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java index 6705134648d..dd81b356738 100644 --- a/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java +++ b/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java @@ -288,7 +288,7 @@ public final class NioSocketImpl extends SocketImpl implements PlatformSocketImp * @throws SocketException if the socket is closed or a socket I/O error occurs * @throws SocketTimeoutException if the read timeout elapses */ - private int implRead(byte[] b, int off, int len) throws IOException { + private int implRead(byte[] b, int off, int len, long remainingNanos) throws IOException { int n = 0; FileDescriptor fd = beginRead(); try { @@ -296,11 +296,10 @@ public final class NioSocketImpl extends SocketImpl implements PlatformSocketImp throw new SocketException("Connection reset"); if (isInputClosed) return -1; - int timeout = this.timeout; - configureNonBlockingIfNeeded(fd, timeout > 0); - if (timeout > 0) { + configureNonBlockingIfNeeded(fd, remainingNanos > 0); + if (remainingNanos > 0) { // read with timeout - n = timedRead(fd, b, off, len, MILLISECONDS.toNanos(timeout)); + n = timedRead(fd, b, off, len, remainingNanos); } else { // read, no timeout n = tryRead(fd, b, off, len); @@ -335,14 +334,24 @@ public final class NioSocketImpl extends SocketImpl implements PlatformSocketImp if (len == 0) { return 0; } else { - readLock.lock(); + long remainingNanos = 0; + int timeout = this.timeout; + if (timeout > 0) { + remainingNanos = tryLock(readLock, timeout, MILLISECONDS); + if (remainingNanos <= 0) { + assert !readLock.isHeldByCurrentThread(); + throw new SocketTimeoutException("Read timed out"); + } + } else { + readLock.lock(); + } try { // emulate legacy behavior to return -1, even if socket is closed if (readEOF) return -1; // read up to MAX_BUFFER_SIZE bytes int size = Math.min(len, MAX_BUFFER_SIZE); - int n = implRead(b, off, size); + int n = implRead(b, off, size, remainingNanos); if (n == -1) readEOF = true; return n; diff --git a/test/jdk/java/net/Socket/Timeouts.java b/test/jdk/java/net/Socket/Timeouts.java index 83bf01ebf50..f8fcfb86d0f 100644 --- a/test/jdk/java/net/Socket/Timeouts.java +++ b/test/jdk/java/net/Socket/Timeouts.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -23,11 +23,10 @@ /* * @test - * @bug 8221481 + * @bug 8221481 8358496 * @library /test/lib * @build jdk.test.lib.Utils - * @compile Timeouts.java - * @run testng/othervm/timeout=180 Timeouts + * @run junit/othervm/timeout=180 Timeouts * @summary Test Socket timeouts */ @@ -43,25 +42,27 @@ import java.net.Socket; import java.net.SocketAddress; import java.net.SocketException; import java.net.SocketTimeoutException; +import java.util.ArrayList; +import java.util.concurrent.Callable; import java.util.concurrent.Executors; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ForkJoinPool; import java.util.concurrent.Future; -import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import org.testng.SkipException; -import org.testng.annotations.Test; -import static org.testng.Assert.*; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assumptions.*; import jdk.test.lib.Utils; -@Test -public class Timeouts { +class Timeouts { /** - * Test timed connect where connection is established + * Test timed connect where connection is established. */ - public void testTimedConnect1() throws IOException { + @Test + void testTimedConnect1() throws IOException { try (ServerSocket ss = boundServerSocket()) { try (Socket s = new Socket()) { s.connect(ss.getLocalSocketAddress(), 2000); @@ -70,21 +71,21 @@ public class Timeouts { } /** - * Test timed connect where connection is refused + * Test timed connect where connection is refused. */ - public void testTimedConnect2() throws IOException { + @Test + void testTimedConnect2() throws IOException { try (Socket s = new Socket()) { SocketAddress remote = Utils.refusingEndpoint(); - try { - s.connect(remote, 10000); - } catch (ConnectException expected) { } + assertThrows(ConnectException.class, () -> s.connect(remote, 10000)); } } /** - * Test connect with a timeout of Integer.MAX_VALUE + * Test connect with a timeout of Integer.MAX_VALUE. */ - public void testTimedConnect3() throws IOException { + @Test + void testTimedConnect3() throws IOException { try (ServerSocket ss = boundServerSocket()) { try (Socket s = new Socket()) { s.connect(ss.getLocalSocketAddress(), Integer.MAX_VALUE); @@ -95,141 +96,183 @@ public class Timeouts { /** * Test connect with a negative timeout. */ - public void testTimedConnect4() throws IOException { + @Test + void testTimedConnect4() throws IOException { try (ServerSocket ss = boundServerSocket()) { try (Socket s = new Socket()) { - expectThrows(IllegalArgumentException.class, + assertThrows(IllegalArgumentException.class, () -> s.connect(ss.getLocalSocketAddress(), -1)); } } } /** - * Test timed read where the read succeeds immediately + * Test timed read where the read succeeds immediately. */ - public void testTimedRead1() throws IOException { + @Test + void testTimedRead1() throws IOException { withConnection((s1, s2) -> { s1.getOutputStream().write(99); s2.setSoTimeout(30*1000); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); }); } /** - * Test timed read where the read succeeds after a delay + * Test timed read where the read succeeds after a delay. */ - public void testTimedRead2() throws IOException { + @Test + void testTimedRead2() throws IOException { withConnection((s1, s2) -> { scheduleWrite(s1.getOutputStream(), 99, 2000); s2.setSoTimeout(30*1000); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); }); } /** - * Test timed read where the read times out + * Test timed read where the read times out. */ - public void testTimedRead3() throws IOException { + @Test + void testTimedRead3() throws IOException { withConnection((s1, s2) -> { s2.setSoTimeout(2000); long startMillis = millisTime(); - expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); + assertThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); int timeout = s2.getSoTimeout(); checkDuration(startMillis, timeout-100, timeout+20_000); }); } /** - * Test timed read that succeeds after a previous read has timed out + * Test timed read that succeeds after a previous read has timed out. */ - public void testTimedRead4() throws IOException { + @Test + void testTimedRead4() throws IOException { withConnection((s1, s2) -> { s2.setSoTimeout(2000); - expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); + assertThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); s1.getOutputStream().write(99); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); }); } /** * Test timed read that succeeds after a previous read has timed out and - * after a short delay + * after a short delay. */ - public void testTimedRead5() throws IOException { + @Test + void testTimedRead5() throws IOException { withConnection((s1, s2) -> { s2.setSoTimeout(2000); - expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); + assertThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); s2.setSoTimeout(30*3000); scheduleWrite(s1.getOutputStream(), 99, 2000); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); }); } /** - * Test untimed read that succeeds after a previous read has timed out + * Test untimed read that succeeds after a previous read has timed out. */ - public void testTimedRead6() throws IOException { + @Test + void testTimedRead6() throws IOException { withConnection((s1, s2) -> { s2.setSoTimeout(2000); - expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); + assertThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); s1.getOutputStream().write(99); s2.setSoTimeout(0); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); }); } /** * Test untimed read that succeeds after a previous read has timed out and - * after a short delay + * after a short delay. */ - public void testTimedRead7() throws IOException { + @Test + void testTimedRead7() throws IOException { withConnection((s1, s2) -> { s2.setSoTimeout(2000); - expectThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); + assertThrows(SocketTimeoutException.class, () -> s2.getInputStream().read()); scheduleWrite(s1.getOutputStream(), 99, 2000); s2.setSoTimeout(0); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); }); } /** - * Test async close of timed read + * Test async close of timed read. */ - public void testTimedRead8() throws IOException { + @Test + void testTimedRead8() throws IOException { withConnection((s1, s2) -> { s2.setSoTimeout(30*1000); scheduleClose(s2, 2000); - expectThrows(SocketException.class, () -> s2.getInputStream().read()); + assertThrows(SocketException.class, () -> s2.getInputStream().read()); }); } /** - * Test read with a timeout of Integer.MAX_VALUE + * Test read with a timeout of Integer.MAX_VALUE. */ - public void testTimedRead9() throws IOException { + @Test + void testTimedRead9() throws IOException { withConnection((s1, s2) -> { scheduleWrite(s1.getOutputStream(), 99, 2000); s2.setSoTimeout(Integer.MAX_VALUE); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); }); } + /** + * Test 100 threads concurrently reading the same Socket with a timeout of 2s. + * Each read should throw SocketTimeoutException after 2s, not 2s for the first, + * 4s for the second, 6s for the third, up to 200s for the last thread. + */ + @Test + void testTimedRead10() throws Exception { + var futures = new ArrayList>(); + withConnection((_, s) -> { + s.setSoTimeout(2000); + Callable timedReadTask = () -> { + long startMillis = millisTime(); + assertThrows(SocketTimeoutException.class, + () -> s.getInputStream().read()); + int timeout = s.getSoTimeout(); + checkDuration(startMillis, timeout-100, timeout+20_000); + return null; + }; + // start 100 virtual threads to read from the socket + try (var executor = Executors.newVirtualThreadPerTaskExecutor()) { + for (int i = 0; i < 100; i++) { + Future future = executor.submit(timedReadTask); + futures.add(future); + } + } + }); + for (Future future : futures) { + future.get(); + } + } + /** * Test writing after a timed read. */ - public void testTimedWrite1() throws IOException { + @Test + void testTimedWrite1() throws IOException { withConnection((s1, s2) -> { s1.getOutputStream().write(99); s2.setSoTimeout(3000); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); // schedule thread to read s1 to EOF scheduleReadToEOF(s1.getInputStream(), 3000); @@ -245,12 +288,13 @@ public class Timeouts { /** * Test async close of writer (after a timed read). */ - public void testTimedWrite2() throws IOException { + @Test + void testTimedWrite2() throws IOException { withConnection((s1, s2) -> { s1.getOutputStream().write(99); s2.setSoTimeout(3000); int b = s2.getInputStream().read(); - assertTrue(b == 99); + assertEquals(99, b); // schedule s2 to be closed scheduleClose(s2, 3000); @@ -266,9 +310,10 @@ public class Timeouts { } /** - * Test timed accept where a connection is established immediately + * Test timed accept where a connection is established immediately. */ - public void testTimedAccept1() throws IOException { + @Test + void testTimedAccept1() throws IOException { Socket s1 = null; Socket s2 = null; try (ServerSocket ss = boundServerSocket()) { @@ -283,9 +328,10 @@ public class Timeouts { } /** - * Test timed accept where a connection is established after a short delay + * Test timed accept where a connection is established after a short delay. */ - public void testTimedAccept2() throws IOException { + @Test + void testTimedAccept2() throws IOException { try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(30*1000); scheduleConnect(ss.getLocalSocketAddress(), 2000); @@ -297,7 +343,8 @@ public class Timeouts { /** * Test timed accept where the accept times out */ - public void testTimedAccept3() throws IOException { + @Test + void testTimedAccept3() throws IOException { try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(2000); long startMillis = millisTime(); @@ -316,7 +363,8 @@ public class Timeouts { * Test timed accept where a connection is established immediately after a * previous accept timed out. */ - public void testTimedAccept4() throws IOException { + @Test + void testTimedAccept4() throws IOException { try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(2000); try { @@ -334,9 +382,10 @@ public class Timeouts { /** * Test untimed accept where a connection is established after a previous - * accept timed out + * accept timed out. */ - public void testTimedAccept5() throws IOException { + @Test + void testTimedAccept5() throws IOException { try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(2000); try { @@ -355,9 +404,10 @@ public class Timeouts { /** * Test untimed accept where a connection is established after a previous - * accept timed out and after a short delay + * accept timed out and after a short delay. */ - public void testTimedAccept6() throws IOException { + @Test + void testTimedAccept6() throws IOException { try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(2000); try { @@ -373,9 +423,10 @@ public class Timeouts { } /** - * Test async close of a timed accept + * Test async close of a timed accept. */ - public void testTimedAccept7() throws IOException { + @Test + void testTimedAccept7() throws IOException { try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(30*1000); long delay = 2000; @@ -393,9 +444,9 @@ public class Timeouts { /** * Test timed accept with the thread interrupt status set. */ - public void testTimedAccept8() throws IOException { - if (Thread.currentThread().isVirtual()) - throw new SkipException("Main test is a virtual thread"); + @Test + void testTimedAccept8() throws IOException { + assumeFalse(Thread.currentThread().isVirtual(), "Main test is a virtual thread"); try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(2000); Thread.currentThread().interrupt(); @@ -418,9 +469,9 @@ public class Timeouts { /** * Test interrupt of thread blocked in timed accept. */ - public void testTimedAccept9() throws IOException { - if (Thread.currentThread().isVirtual()) - throw new SkipException("Main test is a virtual thread"); + @Test + void testTimedAccept9() throws IOException { + assumeFalse(Thread.currentThread().isVirtual(), "Main test is a virtual thread"); try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(4000); // interrupt thread after 1 second @@ -445,7 +496,8 @@ public class Timeouts { /** * Test two threads blocked in timed accept where no connection is established. */ - public void testTimedAccept10() throws Exception { + @Test + void testTimedAccept10() throws Exception { ExecutorService pool = Executors.newFixedThreadPool(2); try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(4000); @@ -456,9 +508,9 @@ public class Timeouts { Future result2 = pool.submit(ss::accept); // both tasks should complete with SocketTimeoutException - Throwable e = expectThrows(ExecutionException.class, result1::get); + Throwable e = assertThrows(ExecutionException.class, result1::get); assertTrue(e.getCause() instanceof SocketTimeoutException); - e = expectThrows(ExecutionException.class, result2::get); + e = assertThrows(ExecutionException.class, result2::get); assertTrue(e.getCause() instanceof SocketTimeoutException); // should get here in 4 seconds, not 8 seconds @@ -472,7 +524,8 @@ public class Timeouts { /** * Test two threads blocked in timed accept where one connection is established. */ - public void testTimedAccept11() throws Exception { + @Test + void testTimedAccept11() throws Exception { ExecutorService pool = Executors.newFixedThreadPool(2); try (ServerSocket ss = boundServerSocket()) { ss.setSoTimeout(4000); @@ -514,25 +567,25 @@ public class Timeouts { /** * Test Socket setSoTimeout with a negative timeout. */ - @Test(expectedExceptions = { IllegalArgumentException.class }) - public void testBadTimeout1() throws IOException { + @Test + void testBadTimeout1() throws IOException { try (Socket s = new Socket()) { - s.setSoTimeout(-1); + assertThrows(IllegalArgumentException.class, () -> s.setSoTimeout(-1)); } } /** * Test ServerSocket setSoTimeout with a negative timeout. */ - @Test(expectedExceptions = { IllegalArgumentException.class }) - public void testBadTimeout2() throws IOException { + @Test + void testBadTimeout2() throws IOException { try (ServerSocket ss = new ServerSocket()) { - ss.setSoTimeout(-1); + assertThrows(IllegalArgumentException.class, () -> ss.setSoTimeout(-1)); } } /** - * Returns a ServerSocket bound to a port on the loopback address + * Returns a ServerSocket bound to a port on the loopback address. */ static ServerSocket boundServerSocket() throws IOException { var loopback = InetAddress.getLoopbackAddress(); @@ -542,14 +595,14 @@ public class Timeouts { } /** - * An operation that accepts two arguments and may throw IOException + * An operation that accepts two arguments and may throw IOException. */ interface ThrowingBiConsumer { void accept(T t, U u) throws IOException; } /** - * Invokes the consumer with a connected pair of sockets + * Invokes the consumer with a connected pair of sockets. */ static void withConnection(ThrowingBiConsumer consumer) throws IOException @@ -568,7 +621,7 @@ public class Timeouts { } /** - * Schedule c to be closed after a delay + * Schedule c to be closed after a delay. */ static void scheduleClose(Closeable c, long delay) { schedule(() -> { @@ -579,14 +632,14 @@ public class Timeouts { } /** - * Schedule thread to be interrupted after a delay + * Schedule thread to be interrupted after a delay. */ static Future scheduleInterrupt(Thread thread, long delay) { return schedule(() -> thread.interrupt(), delay); } /** - * Schedule a thread to connect to the given end point after a delay + * Schedule a thread to connect to the given end point after a delay. */ static void scheduleConnect(SocketAddress remote, long delay) { schedule(() -> { @@ -597,7 +650,7 @@ public class Timeouts { } /** - * Schedule a thread to read to EOF after a delay + * Schedule a thread to read to EOF after a delay. */ static void scheduleReadToEOF(InputStream in, long delay) { schedule(() -> { @@ -609,7 +662,7 @@ public class Timeouts { } /** - * Schedule a thread to write after a delay + * Schedule a thread to write after a delay. */ static void scheduleWrite(OutputStream out, byte[] data, long delay) { schedule(() -> { @@ -623,12 +676,7 @@ public class Timeouts { } static Future schedule(Runnable task, long delay) { - ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); - try { - return executor.schedule(task, delay, TimeUnit.MILLISECONDS); - } finally { - executor.shutdown(); - } + return ForkJoinPool.commonPool().schedule(task, delay, TimeUnit.MILLISECONDS); } /** @@ -640,7 +688,7 @@ public class Timeouts { } /** - * Check the duration of a task + * Check the duration of a task. * @param start start time, in milliseconds * @param min minimum expected duration, in milliseconds * @param max maximum expected duration, in milliseconds