8291897: TerminatingThreadLocal(s) not registered from virtual thread(s)

Reviewed-by: alanb
This commit is contained in:
Peter Levart 2022-08-08 12:38:21 +00:00
parent 8d88be233b
commit 861cc671e2
9 changed files with 250 additions and 66 deletions

View File

@ -67,6 +67,8 @@ import java.util.concurrent.Callable;
import java.util.function.Supplier;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
import jdk.internal.misc.CarrierThreadLocal;
import jdk.internal.misc.Unsafe;
import jdk.internal.util.StaticProperty;
import jdk.internal.module.ModuleBootstrap;
@ -2554,12 +2556,20 @@ public final class System {
}
}
public <T> T getCarrierThreadLocal(ThreadLocal<T> local) {
return local.getCarrierThreadLocal();
public <T> T getCarrierThreadLocal(CarrierThreadLocal<T> local) {
return ((ThreadLocal<T>)local).getCarrierThreadLocal();
}
public <T> void setCarrierThreadLocal(ThreadLocal<T> local, T value) {
local.setCarrierThreadLocal(value);
public <T> void setCarrierThreadLocal(CarrierThreadLocal<T> local, T value) {
((ThreadLocal<T>)local).setCarrierThreadLocal(value);
}
public void removeCarrierThreadLocal(CarrierThreadLocal<?> local) {
((ThreadLocal<?>)local).removeCarrierThreadLocal();
}
public boolean isCarrierThreadLocalPresent(CarrierThreadLocal<?> local) {
return ((ThreadLocal<?>)local).isCarrierThreadLocalPresent();
}
public Object[] extentLocalCache() {

View File

@ -29,6 +29,8 @@ import java.lang.ref.WeakReference;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import jdk.internal.misc.CarrierThreadLocal;
import jdk.internal.misc.TerminatingThreadLocal;
/**
@ -172,6 +174,7 @@ public class ThreadLocal<T> {
* thread-local variable.
*/
T getCarrierThreadLocal() {
assert this instanceof CarrierThreadLocal<T>;
return get(Thread.currentCarrierThread());
}
@ -193,14 +196,18 @@ public class ThreadLocal<T> {
}
/**
* Returns {@code true} if there is a value in the current thread's copy of
* Returns {@code true} if there is a value in the current carrier thread's copy of
* this thread-local variable, even if that values is {@code null}.
*
* @return {@code true} if current thread has associated value in this
* @return {@code true} if current carrier thread has associated value in this
* thread-local variable; {@code false} if not
*/
boolean isPresent() {
Thread t = Thread.currentThread();
boolean isCarrierThreadLocalPresent() {
assert this instanceof CarrierThreadLocal<T>;
return isPresent(Thread.currentCarrierThread());
}
private boolean isPresent(Thread t) {
ThreadLocalMap map = getMap(t);
if (map != null && map != ThreadLocalMap.NOT_SUPPORTED) {
return map.getEntry(this) != null;
@ -224,8 +231,8 @@ public class ThreadLocal<T> {
} else {
createMap(t, value);
}
if (this instanceof TerminatingThreadLocal) {
TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
if (this instanceof TerminatingThreadLocal<?> ttl) {
TerminatingThreadLocal.register(ttl);
}
return value;
}
@ -249,6 +256,7 @@ public class ThreadLocal<T> {
}
void setCarrierThreadLocal(T value) {
assert this instanceof CarrierThreadLocal<T>;
set(Thread.currentCarrierThread(), value);
}
@ -276,7 +284,16 @@ public class ThreadLocal<T> {
* @since 1.5
*/
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
remove(Thread.currentThread());
}
void removeCarrierThreadLocal() {
assert this instanceof CarrierThreadLocal<T>;
remove(Thread.currentCarrierThread());
}
private void remove(Thread t) {
ThreadLocalMap m = getMap(t);
if (m != null && m != ThreadLocalMap.NOT_SUPPORTED) {
m.remove(this);
}

View File

@ -27,7 +27,6 @@ package jdk.internal.access;
import java.lang.annotation.Annotation;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.module.ModuleDescriptor;
import java.lang.reflect.Executable;
@ -45,6 +44,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.stream.Stream;
import jdk.internal.misc.CarrierThreadLocal;
import jdk.internal.module.ServicesCatalog;
import jdk.internal.reflect.ConstantPool;
import jdk.internal.vm.Continuation;
@ -456,12 +456,23 @@ public interface JavaLangAccess {
/**
* Returns the value of the current carrier thread's copy of a thread-local.
*/
<T> T getCarrierThreadLocal(ThreadLocal<T> local);
<T> T getCarrierThreadLocal(CarrierThreadLocal<T> local);
/**
* Sets the value of the current carrier thread's copy of a thread-local.
*/
<T> void setCarrierThreadLocal(ThreadLocal<T> local, T value);
<T> void setCarrierThreadLocal(CarrierThreadLocal<T> local, T value);
/**
* Removes the value of the current carrier thread's copy of a thread-local.
*/
void removeCarrierThreadLocal(CarrierThreadLocal<?> local);
/**
* Returns {@code true} if there is a value in the current carrier thread's copy of
* thread-local, even if that values is {@code null}.
*/
boolean isCarrierThreadLocalPresent(CarrierThreadLocal<?> local);
/**
* Returns the current thread's extent locals cache

View File

@ -0,0 +1,57 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.internal.misc;
import jdk.internal.access.JavaLangAccess;
import jdk.internal.access.SharedSecrets;
/**
* A {@link ThreadLocal} variant which binds its value to current thread's
* carrier thread.
*/
public class CarrierThreadLocal<T> extends ThreadLocal<T> {
@Override
public T get() {
return JLA.getCarrierThreadLocal(this);
}
@Override
public void set(T value) {
JLA.setCarrierThreadLocal(this, value);
}
@Override
public void remove() {
JLA.removeCarrierThreadLocal(this);
}
public boolean isPresent() {
return JLA.isCarrierThreadLocalPresent(this);
}
private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();
}

View File

@ -29,11 +29,12 @@ import java.util.Collections;
import java.util.IdentityHashMap;
/**
* A thread-local variable that is notified when a thread terminates and
* it has been initialized in the terminating thread (even if it was
* A per-carrier-thread-local variable that is notified when a thread terminates and
* it has been initialized in the terminating carrier thread or a virtual thread
* that had the terminating carrier thread as its carrier thread (even if it was
* initialized with a null value).
*/
public class TerminatingThreadLocal<T> extends ThreadLocal<T> {
public class TerminatingThreadLocal<T> extends CarrierThreadLocal<T> {
@Override
public void set(T value) {
@ -79,8 +80,7 @@ public class TerminatingThreadLocal<T> extends ThreadLocal<T> {
* @param tl the ThreadLocal to register
*/
public static void register(TerminatingThreadLocal<?> tl) {
if (!Thread.currentThread().isVirtual())
REGISTRY.get().add(tl);
REGISTRY.get().add(tl);
}
/**
@ -89,16 +89,15 @@ public class TerminatingThreadLocal<T> extends ThreadLocal<T> {
* @param tl the ThreadLocal to unregister
*/
private static void unregister(TerminatingThreadLocal<?> tl) {
if (!Thread.currentThread().isVirtual())
REGISTRY.get().remove(tl);
REGISTRY.get().remove(tl);
}
/**
* a per-thread registry of TerminatingThreadLocal(s) that have been registered
* but later not unregistered in a particular thread.
* a per-carrier-thread registry of TerminatingThreadLocal(s) that have been registered
* but later not unregistered in a particular carrier-thread.
*/
public static final ThreadLocal<Collection<TerminatingThreadLocal<?>>> REGISTRY =
new ThreadLocal<>() {
public static final CarrierThreadLocal<Collection<TerminatingThreadLocal<?>>> REGISTRY =
new CarrierThreadLocal<>() {
@Override
protected Collection<TerminatingThreadLocal<?>> initialValue() {
return Collections.newSetFromMap(new IdentityHashMap<>(4));

View File

@ -27,8 +27,7 @@ package sun.nio.ch;
import java.nio.ByteBuffer;
import jdk.internal.access.JavaLangAccess;
import jdk.internal.access.SharedSecrets;
import jdk.internal.misc.CarrierThreadLocal;
import jdk.internal.ref.CleanerFactory;
/**
@ -46,7 +45,6 @@ import jdk.internal.ref.CleanerFactory;
*/
class IOVecWrapper {
private static final JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();
// Miscellaneous constants
private static final int BASE_OFFSET = 0;
@ -83,8 +81,8 @@ class IOVecWrapper {
}
}
// per thread IOVecWrapper
private static final ThreadLocal<IOVecWrapper> cached = new ThreadLocal<>();
// per carrier-thread IOVecWrapper
private static final CarrierThreadLocal<IOVecWrapper> cached = new CarrierThreadLocal<>();
private IOVecWrapper(int size) {
this.size = size;
@ -97,7 +95,7 @@ class IOVecWrapper {
}
static IOVecWrapper get(int size) {
IOVecWrapper wrapper = JLA.getCarrierThreadLocal(cached);
IOVecWrapper wrapper = cached.get();
if (wrapper != null && wrapper.size < size) {
// not big enough; eagerly release memory
wrapper.vecArray.free();
@ -106,7 +104,7 @@ class IOVecWrapper {
if (wrapper == null) {
wrapper = new IOVecWrapper(size);
CleanerFactory.cleaner().register(wrapper, new Deallocator(wrapper.vecArray));
JLA.setCarrierThreadLocal(cached, wrapper);
cached.set(wrapper);
}
return wrapper;
}

View File

@ -38,14 +38,11 @@ import java.util.Collection;
import java.util.Iterator;
import java.util.Set;
import jdk.internal.access.JavaLangAccess;
import jdk.internal.access.SharedSecrets;
import jdk.internal.misc.TerminatingThreadLocal;
import jdk.internal.misc.Unsafe;
import sun.security.action.GetPropertyAction;
public class Util {
private static JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();
// -- Caches --
@ -55,8 +52,8 @@ public class Util {
// The max size allowed for a cached temp buffer, in bytes
private static final long MAX_CACHED_BUFFER_SIZE = getMaxCachedBufferSize();
// Per-thread cache of temporary direct buffers
private static ThreadLocal<BufferCache> bufferCache = new TerminatingThreadLocal<>() {
// Per-carrier-thread cache of temporary direct buffers
private static TerminatingThreadLocal<BufferCache> bufferCache = new TerminatingThreadLocal<>() {
@Override
protected BufferCache initialValue() {
return new BufferCache();
@ -230,7 +227,7 @@ public class Util {
return ByteBuffer.allocateDirect(size);
}
BufferCache cache = JLA.getCarrierThreadLocal(bufferCache);
BufferCache cache = bufferCache.get();
ByteBuffer buf = cache.get(size);
if (buf != null) {
return buf;
@ -257,7 +254,7 @@ public class Util {
.alignedSlice(alignment);
}
BufferCache cache = JLA.getCarrierThreadLocal(bufferCache);
BufferCache cache = bufferCache.get();
ByteBuffer buf = cache.get(size);
if (buf != null) {
if (buf.alignmentOffset(0, alignment) == 0) {
@ -294,7 +291,7 @@ public class Util {
}
assert buf != null;
BufferCache cache = JLA.getCarrierThreadLocal(bufferCache);
BufferCache cache = bufferCache.get();
if (!cache.offerFirst(buf)) {
// cache is full
free(buf);
@ -316,7 +313,7 @@ public class Util {
}
assert buf != null;
BufferCache cache = JLA.getCarrierThreadLocal(bufferCache);
BufferCache cache = bufferCache.get();
if (!cache.offerLast(buf)) {
// cache is full
free(buf);

View File

@ -25,8 +25,6 @@
package sun.nio.fs;
import jdk.internal.access.JavaLangAccess;
import jdk.internal.access.SharedSecrets;
import jdk.internal.misc.TerminatingThreadLocal;
import jdk.internal.misc.Unsafe;
@ -35,12 +33,12 @@ import jdk.internal.misc.Unsafe;
*/
class NativeBuffers {
private static JavaLangAccess JLA = SharedSecrets.getJavaLangAccess();
private static final Unsafe unsafe = Unsafe.getUnsafe();
private static final int TEMP_BUF_POOL_SIZE = 3;
private static ThreadLocal<NativeBuffer[]> threadLocal = new TerminatingThreadLocal<>() {
// per-carrier-thread cache of NativeBuffer(s)
private static final TerminatingThreadLocal<NativeBuffer[]> threadLocal = new TerminatingThreadLocal<>() {
@Override
protected void threadTerminated(NativeBuffer[] buffers) {
// threadLocal may be initialized but with initialValue of null
@ -73,7 +71,7 @@ class NativeBuffers {
*/
static NativeBuffer getNativeBufferFromCache(int size) {
// return from cache if possible
NativeBuffer[] buffers = JLA.getCarrierThreadLocal(threadLocal);
NativeBuffer[] buffers = threadLocal.get();
if (buffers != null) {
for (int i=0; i<TEMP_BUF_POOL_SIZE; i++) {
NativeBuffer buffer = buffers[i];
@ -107,11 +105,11 @@ class NativeBuffers {
*/
static void releaseNativeBuffer(NativeBuffer buffer) {
// create cache if it doesn't exist
NativeBuffer[] buffers = JLA.getCarrierThreadLocal(threadLocal);
NativeBuffer[] buffers = threadLocal.get();
if (buffers == null) {
buffers = new NativeBuffer[TEMP_BUF_POOL_SIZE];
buffers[0] = buffer;
JLA.setCarrierThreadLocal(threadLocal, buffers);
threadLocal.set(buffers);
return;
}
// Put it in an empty slot if such exists

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -26,39 +26,44 @@ import jdk.internal.misc.TerminatingThreadLocal;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
/*
* @test
* @bug 8202788
* @bug 8202788 8291897
* @summary TerminatingThreadLocal unit test
* @modules java.base/jdk.internal.misc
* @run main TestTerminatingThreadLocal
* @requires vm.continuations
* @enablePreview
* @run main/othervm -Djdk.virtualThreadScheduler.parallelism=1 -Djdk.virtualThreadScheduler.maxPoolSize=2 TestTerminatingThreadLocal
*/
public class TestTerminatingThreadLocal {
public static void main(String[] args) {
ttlTestSet(42, 112);
ttlTestSet(42, 112);
ttlTestSet(null, 112);
ttlTestSet(42, null);
ttlTestSet(42, null);
ttlTestVirtual(666, ThreadLocal::get, 666);
}
static <T> void ttlTestSet(T v0, T v1) {
ttlTest(v0, ttl -> { } );
ttlTest(v0, ttl -> { ttl.get(); }, v0);
ttlTest(v0, ttl -> { ttl.get(); ttl.remove(); } );
ttlTest(v0, ttl -> { ttl.get(); ttl.set(v1); }, v1);
ttlTest(v0, ttl -> { ttl.set(v1); }, v1);
ttlTest(v0, ttl -> { ttl.set(v1); ttl.remove(); } );
ttlTest(v0, ttl -> { ttl.set(v1); ttl.remove(); ttl.get(); }, v0);
ttlTest(v0, ttl -> { ttl.get(); ttl.remove(); ttl.set(v1); }, v1);
ttlTestPlatform(v0, ttl -> { } );
ttlTestPlatform(v0, ttl -> { ttl.get(); }, v0);
ttlTestPlatform(v0, ttl -> { ttl.get(); ttl.remove(); } );
ttlTestPlatform(v0, ttl -> { ttl.get(); ttl.set(v1); }, v1);
ttlTestPlatform(v0, ttl -> { ttl.set(v1); }, v1);
ttlTestPlatform(v0, ttl -> { ttl.set(v1); ttl.remove(); } );
ttlTestPlatform(v0, ttl -> { ttl.set(v1); ttl.remove(); ttl.get(); }, v0);
ttlTestPlatform(v0, ttl -> { ttl.get(); ttl.remove(); ttl.set(v1); }, v1);
}
@SafeVarargs
static <T> void ttlTest(T initialValue,
Consumer<? super TerminatingThreadLocal<T>> ttlOps,
T... expectedTerminatedValues)
{
static <T> void ttlTestPlatform(T initialValue,
Consumer<? super TerminatingThreadLocal<T>> ttlOps,
T... expectedTerminatedValues) {
List<T> terminatedValues = new CopyOnWriteArrayList<>();
TerminatingThreadLocal<T> ttl = new TerminatingThreadLocal<>() {
@ -73,7 +78,7 @@ public class TestTerminatingThreadLocal {
}
};
Thread thread = new Thread(() -> ttlOps.accept(ttl));
Thread thread = new Thread(() -> ttlOps.accept(ttl), "ttl-test-platform");
thread.start();
try {
thread.join();
@ -87,4 +92,96 @@ public class TestTerminatingThreadLocal {
" but got: " + terminatedValues);
}
}
@SafeVarargs
static <T> void ttlTestVirtual(T initialValue,
Consumer<? super TerminatingThreadLocal<T>> ttlOps,
T... expectedTerminatedValues) {
List<T> terminatedValues = new CopyOnWriteArrayList<>();
TerminatingThreadLocal<T> ttl = new TerminatingThreadLocal<>() {
@Override
protected void threadTerminated(T value) {
terminatedValues.add(value);
}
@Override
protected T initialValue() {
return initialValue;
}
};
var lock = new Lock();
var blockerThread = Thread.startVirtualThread(() -> {
// force compensation in carrier thread pool which will spin another
// carrier thread so that we can later observe it being terminated...
synchronized (lock) {
while (!lock.unblock) {
try {
lock.wait();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
}
// keep thread running in a non-blocking-fashion which keeps
// it bound to carrier thread
while (!lock.unspin) {
Thread.onSpinWait();
}
});
Thread thread = Thread
.ofVirtual()
.allowSetThreadLocals(false)
.inheritInheritableThreadLocals(false)
.name("ttl-test-virtual")
.unstarted(() -> ttlOps.accept(ttl));
thread.start();
try {
thread.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
if (!terminatedValues.isEmpty()) {
throw new AssertionError("Unexpected terminated values after virtual thread.join(): " +
terminatedValues);
}
// we now unblock the blocker thread but keep it running
synchronized (lock) {
lock.unblock = true;
lock.notify();
}
// carrier thread pool has a 30 second keep-alive time to terminate excessive carrier
// threads. Since blockerThread is still pinning one of them we hope for the other
// thread to be terminated...
try {
TimeUnit.SECONDS.sleep(31);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
if (!terminatedValues.equals(Arrays.asList(expectedTerminatedValues))) {
throw new AssertionError("Expected terminated values: " +
Arrays.toString(expectedTerminatedValues) +
" but got: " + terminatedValues);
}
// we now terminate the blocker thread
lock.unspin = true;
try {
blockerThread.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
static class Lock {
boolean unblock;
volatile boolean unspin;
}
}