diff --git a/test/jdk/java/util/concurrent/ThreadPerTaskExecutor/ThreadPerTaskExecutorTest.java b/test/jdk/java/util/concurrent/ThreadPerTaskExecutor/ThreadPerTaskExecutorTest.java index d57cacb7001..a599ef234dd 100644 --- a/test/jdk/java/util/concurrent/ThreadPerTaskExecutor/ThreadPerTaskExecutorTest.java +++ b/test/jdk/java/util/concurrent/ThreadPerTaskExecutor/ThreadPerTaskExecutorTest.java @@ -44,6 +44,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.Stream; +import static java.lang.Thread.State.*; import static java.util.concurrent.Future.State.*; import org.junit.jupiter.api.Test; @@ -54,12 +55,6 @@ import org.junit.jupiter.params.provider.MethodSource; import static org.junit.jupiter.api.Assertions.*; class ThreadPerTaskExecutorTest { - // long running interruptible task - private static final Callable SLEEP_FOR_A_DAY = () -> { - Thread.sleep(Duration.ofDays(1)); - return null; - }; - private static ScheduledExecutorService scheduler; private static List threadFactories; @@ -92,14 +87,6 @@ class ThreadPerTaskExecutorTest { .map(f -> Executors.newThreadPerTaskExecutor(f)); } - /** - * Schedules a thread to be interrupted after the given delay. - */ - private void scheduleInterrupt(Thread thread, Duration delay) { - long millis = delay.toMillis(); - scheduler.schedule(thread::interrupt, millis, TimeUnit.MILLISECONDS); - } - /** * Test that a thread is created for each task. */ @@ -164,14 +151,14 @@ class ThreadPerTaskExecutorTest { assertFalse(executor.isTerminated()); assertFalse(executor.awaitTermination(10, TimeUnit.MILLISECONDS)); - Future result = executor.submit(SLEEP_FOR_A_DAY); + Future future = executor.submit(new LongRunningTask()); try { executor.shutdown(); assertTrue(executor.isShutdown()); assertFalse(executor.isTerminated()); assertFalse(executor.awaitTermination(500, TimeUnit.MILLISECONDS)); } finally { - result.cancel(true); + future.cancel(true); // interrupt task } } } @@ -187,19 +174,22 @@ class ThreadPerTaskExecutorTest { assertFalse(executor.isTerminated()); assertFalse(executor.awaitTermination(10, TimeUnit.MILLISECONDS)); - Future result = executor.submit(SLEEP_FOR_A_DAY); + var task = new LongRunningTask(); + Future future = executor.submit(task); try { + task.awaitStarted(); + List tasks = executor.shutdownNow(); assertTrue(executor.isShutdown()); assertTrue(tasks.isEmpty()); - Throwable e = assertThrows(ExecutionException.class, result::get); + Throwable e = assertThrows(ExecutionException.class, future::get); assertTrue(e.getCause() instanceof InterruptedException); assertTrue(executor.awaitTermination(3, TimeUnit.SECONDS)); assertTrue(executor.isTerminated()); } finally { - result.cancel(true); + future.cancel(true); } } } @@ -236,23 +226,25 @@ class ThreadPerTaskExecutorTest { } /** - * Invoke close with interrupt status set, should cancel task. + * Invoke close with interrupt status set. */ @ParameterizedTest @MethodSource("executors") void testClose3(ExecutorService executor) throws Exception { - Future future; + Future future; try (executor) { - future = executor.submit(SLEEP_FOR_A_DAY); + var task = new LongRunningTask(); + future = executor.submit(task); + task.awaitStarted(); Thread.currentThread().interrupt(); } finally { assertTrue(Thread.interrupted()); // clear interrupt } - assertTrue(executor.isShutdown()); assertTrue(executor.isTerminated()); - assertTrue(executor.awaitTermination(10, TimeUnit.MILLISECONDS)); - assertThrows(ExecutionException.class, future::get); + assertTrue(executor.awaitTermination(10, TimeUnit.MILLISECONDS)); + Throwable e = assertThrows(ExecutionException.class, future::get); + assertTrue(e.getCause() instanceof InterruptedException); } /** @@ -261,17 +253,20 @@ class ThreadPerTaskExecutorTest { @ParameterizedTest @MethodSource("executors") void testClose4(ExecutorService executor) throws Exception { - Future future; + Future future; try (executor) { - future = executor.submit(SLEEP_FOR_A_DAY); - scheduleInterrupt(Thread.currentThread(), Duration.ofMillis(500)); + var task = new LongRunningTask(); + future = executor.submit(task); + task.awaitStarted(); + scheduleInterruptAt("java.util.concurrent.ThreadPerTaskExecutor.close"); } finally { assertTrue(Thread.interrupted()); } assertTrue(executor.isShutdown()); assertTrue(executor.isTerminated()); - assertTrue(executor.awaitTermination(10, TimeUnit.MILLISECONDS)); - assertThrows(ExecutionException.class, future::get); + assertTrue(executor.awaitTermination(10, TimeUnit.MILLISECONDS)); + Throwable e = assertThrows(ExecutionException.class, future::get); + assertTrue(e.getCause() instanceof InterruptedException); } /** @@ -302,14 +297,14 @@ class ThreadPerTaskExecutorTest { @MethodSource("executors") void testAwaitTermination2(ExecutorService executor) throws Exception { Phaser barrier = new Phaser(2); - Future result = executor.submit(barrier::arriveAndAwaitAdvance); + Future future = executor.submit(barrier::arriveAndAwaitAdvance); try { executor.shutdown(); assertFalse(executor.awaitTermination(100, TimeUnit.MILLISECONDS)); barrier.arriveAndAwaitAdvance(); assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS)); } finally { - result.cancel(true); + future.cancel(true); } } @@ -384,28 +379,15 @@ class ThreadPerTaskExecutorTest { @MethodSource("executors") void testInvokeAny2(ExecutorService executor) throws Exception { try (executor) { - AtomicBoolean task2Started = new AtomicBoolean(); - AtomicReference task2Exception = new AtomicReference<>(); Callable task1 = () -> "foo"; - Callable task2 = () -> { - task2Started.set(true); - try { - Thread.sleep(Duration.ofDays(1)); - } catch (Exception e) { - task2Exception.set(e); - } - return "bar"; - }; + var task2 = new LongRunningTask(); String result = executor.invokeAny(Set.of(task1, task2)); assertTrue("foo".equals(result)); - // if task2 started then the sleep should have been interrupted - if (task2Started.get()) { - Throwable exc; - while ((exc = task2Exception.get()) == null) { - Thread.sleep(20); - } - assertTrue(exc instanceof InterruptedException); + // if task2 started then it should be interrupted + if (task2.isStarted()) { + task2.awaitDone(); + assertTrue(task2.isInterrupted()); } } } @@ -510,28 +492,15 @@ class ThreadPerTaskExecutorTest { @MethodSource("executors") void testInvokeAnyWithTimeout2(ExecutorService executor) throws Exception { try (executor) { - AtomicBoolean task2Started = new AtomicBoolean(); - AtomicReference task2Exception = new AtomicReference<>(); Callable task1 = () -> "foo"; - Callable task2 = () -> { - task2Started.set(true); - try { - Thread.sleep(Duration.ofDays(1)); - } catch (Exception e) { - task2Exception.set(e); - } - return "bar"; - }; + var task2 = new LongRunningTask(); String result = executor.invokeAny(Set.of(task1, task2), 1, TimeUnit.MINUTES); assertTrue("foo".equals(result)); - // if task2 started then the sleep should have been interrupted - if (task2Started.get()) { - Throwable exc; - while ((exc = task2Exception.get()) == null) { - Thread.sleep(20); - } - assertTrue(exc instanceof InterruptedException); + // if task2 started then it should be interrupted + if (task2.isStarted()) { + task2.awaitDone(); + assertTrue(task2.isInterrupted()); } } } @@ -603,20 +572,19 @@ class ThreadPerTaskExecutorTest { @MethodSource("executors") void testInterruptInvokeAny(ExecutorService executor) throws Exception { try (executor) { - Callable task1 = () -> { - Thread.sleep(Duration.ofMinutes(1)); - return "foo"; - }; - Callable task2 = () -> { - Thread.sleep(Duration.ofMinutes(2)); - return "bar"; - }; - scheduleInterrupt(Thread.currentThread(), Duration.ofMillis(500)); + var task = new LongRunningTask(); try { - executor.invokeAny(Set.of(task1, task2)); + scheduleInterruptAt("java.util.concurrent.ThreadPerTaskExecutor.invokeAny"); + executor.invokeAny(Set.of(task)); fail("invokeAny did not throw"); } catch (InterruptedException expected) { assertFalse(Thread.currentThread().isInterrupted()); + + // if task started then it should be interrupted + if (task.isStarted()) { + task.awaitDone(); + assertTrue(task.isInterrupted()); + } } finally { Thread.interrupted(); // clear interrupt } @@ -806,7 +774,7 @@ class ThreadPerTaskExecutorTest { } /** - * Test invokeAll with interrupt status set. + * Test untimed-invokeAll with interrupt status set. */ @ParameterizedTest @MethodSource("executors") @@ -817,7 +785,6 @@ class ThreadPerTaskExecutorTest { Thread.sleep(Duration.ofMinutes(1)); return "bar"; }; - Thread.currentThread().interrupt(); try { executor.invokeAll(List.of(task1, task2)); @@ -856,26 +823,25 @@ class ThreadPerTaskExecutorTest { } /** - * Test interrupt with thread blocked in invokeAll. + * Test interrupt of thread blocked in untimed-invokeAll. */ @ParameterizedTest @MethodSource("executors") void testInvokeAllInterrupt4(ExecutorService executor) throws Exception { try (executor) { - Callable task1 = () -> "foo"; - DelayedResult task2 = new DelayedResult("bar", Duration.ofMinutes(1)); - scheduleInterrupt(Thread.currentThread(), Duration.ofMillis(500)); + var task = new LongRunningTask(); try { - executor.invokeAll(Set.of(task1, task2)); + scheduleInterruptAt("java.util.concurrent.ThreadPerTaskExecutor.invokeAll"); + executor.invokeAll(Set.of(task)); fail("invokeAll did not throw"); } catch (InterruptedException expected) { assertFalse(Thread.currentThread().isInterrupted()); - // task2 should have been interrupted - while (!task2.isDone()) { - Thread.sleep(Duration.ofMillis(100)); + // if task started then it should be interrupted + if (task.isStarted()) { + task.awaitDone(); + assertTrue(task.isInterrupted()); } - assertTrue(task2.exception() instanceof InterruptedException); } finally { Thread.interrupted(); // clear interrupt } @@ -883,26 +849,25 @@ class ThreadPerTaskExecutorTest { } /** - * Test interrupt with thread blocked in timed-invokeAll. + * Test interrupt of thread blocked in timed-invokeAll. */ @ParameterizedTest @MethodSource("executors") - void testInvokeAllInterrupt6(ExecutorService executor) throws Exception { + void testInvokeAllInterrupt5(ExecutorService executor) throws Exception { try (executor) { - Callable task1 = () -> "foo"; - DelayedResult task2 = new DelayedResult("bar", Duration.ofMinutes(1)); - scheduleInterrupt(Thread.currentThread(), Duration.ofMillis(500)); + var task = new LongRunningTask(); try { - executor.invokeAll(Set.of(task1, task2), 1, TimeUnit.DAYS); + scheduleInterruptAt("java.util.concurrent.ThreadPerTaskExecutor.invokeAll"); + executor.invokeAll(Set.of(task), 1, TimeUnit.DAYS); fail("invokeAll did not throw"); } catch (InterruptedException expected) { assertFalse(Thread.currentThread().isInterrupted()); - // task2 should have been interrupted - while (!task2.isDone()) { - Thread.sleep(Duration.ofMillis(100)); + // if task started then it should be interrupted + if (task.isStarted()) { + task.awaitDone(); + assertTrue(task.isInterrupted()); } - assertTrue(task2.exception() instanceof InterruptedException); } finally { Thread.interrupted(); // clear interrupt } @@ -1040,33 +1005,98 @@ class ThreadPerTaskExecutorTest { () -> Executors.newThreadPerTaskExecutor(null)); } - // -- supporting classes -- - - static class DelayedResult implements Callable { - final T result; - final Duration delay; - volatile boolean done; - volatile Exception exception; - DelayedResult(T result, Duration delay) { - this.result = result; - this.delay = delay; - } - public T call() throws Exception { + /** + * Schedules the current thread to be interrupted when it waits (timed or untimed) + * at the given location "{@code c.m}" where {@code c} is the fully qualified class + * name and {@code m} is the method name. + */ + private void scheduleInterruptAt(String location) { + int index = location.lastIndexOf('.'); + String className = location.substring(0, index); + String methodName = location.substring(index + 1); + Thread target = Thread.currentThread(); + scheduler.submit(() -> { try { - Thread.sleep(delay); - return result; + boolean found = false; + while (!found) { + Thread.State state = target.getState(); + assertTrue(state != TERMINATED); + if ((state == WAITING || state == TIMED_WAITING) + && contains(target.getStackTrace(), className, methodName)) { + found = true; + } else { + Thread.sleep(20); + } + } + target.interrupt(); } catch (Exception e) { - this.exception = e; + e.printStackTrace(); + } + }); + } + + /** + * Returns true if the given stack trace contains an element for the given class + * and method name. + */ + private boolean contains(StackTraceElement[] stack, String className, String methodName) { + return Arrays.stream(stack) + .anyMatch(e -> className.equals(e.getClassName()) + && methodName.equals(e.getMethodName())); + } + + /** + * Long running task with methods to test if the task has started, finished, + * and interrupted. + */ + private static class LongRunningTask implements Callable { + final CountDownLatch started = new CountDownLatch(1); + final CountDownLatch done = new CountDownLatch(1); + volatile boolean interrupted; + + @Override + public T call() throws InterruptedException { + started.countDown(); + try { + Thread.sleep(Duration.ofDays(1)); + } catch (InterruptedException e) { + interrupted = true; throw e; } finally { - done = true; + done.countDown(); } + return null; } - boolean isDone() { - return done; + + /** + * Wait for the task to start execution. + */ + LongRunningTask awaitStarted() throws InterruptedException { + started.await(); + return this; } - Exception exception() { - return exception; + + /** + * Wait for the task to finish execution. + */ + LongRunningTask awaitDone() throws InterruptedException { + done.await(); + return this; + } + + /** + * Returns true if the task started execution. + */ + boolean isStarted() { + return started.getCount() == 0; + } + + /** + * Returns true if the task was interrupted. + */ + boolean isInterrupted() { + assertTrue(done.getCount() == 0); // shouldn't call before finished + return interrupted; } } }