mirror of
https://github.com/openjdk/jdk.git
synced 2026-06-13 14:05:08 +00:00
544 lines
19 KiB
Java
544 lines
19 KiB
Java
/*
|
|
* Copyright (c) 2011, 2025, Oracle and/or its affiliates. All rights reserved.
|
|
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
|
*
|
|
* This code is free software; you can redistribute it and/or modify it
|
|
* under the terms of the GNU General Public License version 2 only, as
|
|
* published by the Free Software Foundation.
|
|
*
|
|
* This code is distributed in the hope that it will be useful, but WITHOUT
|
|
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
|
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
|
* version 2 for more details (a copy is included in the LICENSE file that
|
|
* accompanied this code).
|
|
*
|
|
* You should have received a copy of the GNU General Public License version
|
|
* 2 along with this work; if not, write to the Free Software Foundation,
|
|
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
|
*
|
|
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
|
* or visit www.oracle.com if you need additional information or have any
|
|
* questions.
|
|
*/
|
|
|
|
/*
|
|
* @test
|
|
* @bug 8351045 8351996 8358535
|
|
* @summary tests for class-specific values
|
|
* @modules java.base/java.lang:+open
|
|
* @library /test/lib
|
|
* @run junit ClassValueTest
|
|
*/
|
|
|
|
import java.lang.classfile.ClassFile;
|
|
import java.lang.constant.ClassDesc;
|
|
import java.lang.invoke.MethodHandle;
|
|
import java.lang.invoke.MethodHandles;
|
|
import java.lang.invoke.MethodType;
|
|
import java.lang.ref.WeakReference;
|
|
import java.time.Duration;
|
|
import java.time.temporal.ChronoUnit;
|
|
import java.util.ArrayList;
|
|
import java.util.Arrays;
|
|
import java.util.List;
|
|
import java.util.concurrent.CountDownLatch;
|
|
import java.util.concurrent.ThreadLocalRandom;
|
|
import java.util.concurrent.TimeUnit;
|
|
import java.util.concurrent.atomic.AtomicInteger;
|
|
import java.util.concurrent.atomic.AtomicReference;
|
|
|
|
import jdk.test.lib.util.ForceGC;
|
|
import org.junit.jupiter.api.Disabled;
|
|
import org.junit.jupiter.api.Test;
|
|
|
|
import static org.junit.jupiter.api.Assertions.*;
|
|
|
|
/**
|
|
* @author jrose
|
|
*/
|
|
final class ClassValueTest {
|
|
static String nameForCV1(Class<?> type) {
|
|
return "CV1:" + type.getName();
|
|
}
|
|
int countForCV1;
|
|
final ClassValue<String> CV1 = new CV1();
|
|
private class CV1 extends ClassValue<String> {
|
|
protected String computeValue(Class<?> type) {
|
|
countForCV1++;
|
|
return nameForCV1(type);
|
|
}
|
|
}
|
|
|
|
static final Class<?>[] CLASSES = {
|
|
String.class,
|
|
Integer.class,
|
|
int.class,
|
|
boolean[].class,
|
|
char[][].class,
|
|
ClassValueTest.class
|
|
};
|
|
|
|
@Test
|
|
public void testGet() {
|
|
countForCV1 = 0;
|
|
for (Class<?> c : CLASSES) {
|
|
assertEquals(nameForCV1(c), CV1.get(c));
|
|
}
|
|
assertEquals(CLASSES.length, countForCV1);
|
|
for (Class<?> c : CLASSES) {
|
|
assertEquals(nameForCV1(c), CV1.get(c));
|
|
}
|
|
assertEquals(CLASSES.length, countForCV1);
|
|
}
|
|
|
|
@Test
|
|
public void testRemove() {
|
|
for (Class<?> c : CLASSES) {
|
|
CV1.get(c);
|
|
}
|
|
countForCV1 = 0;
|
|
int REMCOUNT = 3;
|
|
for (int i = 0; i < REMCOUNT; i++) {
|
|
CV1.remove(CLASSES[i]);
|
|
}
|
|
assertEquals(0, countForCV1); // no change
|
|
for (Class<?> c : CLASSES) {
|
|
assertEquals(nameForCV1(c), CV1.get(c));
|
|
}
|
|
assertEquals(REMCOUNT, countForCV1);
|
|
}
|
|
|
|
static String nameForCVN(Class<?> type, int n) {
|
|
return "CV[" + n + "]" + type.getName();
|
|
}
|
|
int countForCVN;
|
|
class CVN extends ClassValue<String> {
|
|
final int n;
|
|
CVN(int n) { this.n = n; }
|
|
protected String computeValue(Class<?> type) {
|
|
countForCVN++;
|
|
return nameForCVN(type, n);
|
|
}
|
|
};
|
|
|
|
@Test
|
|
public void testGetMany() {
|
|
int CVN_COUNT1 = 100, CVN_COUNT2 = 100;
|
|
CVN cvns[] = new CVN[CVN_COUNT1 * CVN_COUNT2];
|
|
for (int n = 0; n < cvns.length; n++) {
|
|
cvns[n] = new CVN(n);
|
|
}
|
|
countForCVN = 0;
|
|
for (int pass = 0; pass <= 2; pass++) {
|
|
for (int i1 = 0; i1 < CVN_COUNT1; i1++) {
|
|
eachClass:
|
|
for (Class<?> c : CLASSES) {
|
|
for (int i2 = 0; i2 < CVN_COUNT2; i2++) {
|
|
int n = i1*CVN_COUNT2 + i2;
|
|
assertEquals(0, countForCVN);
|
|
assertEquals(nameForCVN(c, n), cvns[n].get(c));
|
|
cvns[n].get(c); //get it again
|
|
//System.out.println("getting "+n+":"+cvns[n].get(c));
|
|
boolean doremove = (((i1 + i2) & 3) == 0);
|
|
switch (pass) {
|
|
case 0:
|
|
assertEquals(1, countForCVN);
|
|
break;
|
|
case 1:
|
|
// remove on middle pass
|
|
assertEquals(0, countForCVN);
|
|
if (doremove) {
|
|
//System.out.println("removing "+n+":"+cvns[n].get(c));
|
|
cvns[n].remove(c);
|
|
assertEquals(0, countForCVN);
|
|
}
|
|
break;
|
|
case 2:
|
|
assertEquals(doremove ? 1 : 0, countForCVN);
|
|
break;
|
|
}
|
|
countForCVN = 0;
|
|
if (i1 > i2 && i1 < i2+5) continue eachClass; // leave diagonal gap
|
|
}
|
|
}
|
|
}
|
|
}
|
|
assertEquals(0, countForCVN);
|
|
System.out.println("[rechecking values]");
|
|
for (int i = 0; i < cvns.length * 10; i++) {
|
|
int n = i % cvns.length;
|
|
for (Class<?> c : CLASSES) {
|
|
assertEquals(nameForCVN(c, n), cvns[n].get(c));
|
|
}
|
|
}
|
|
}
|
|
|
|
private static final long COMPUTE_TIME_MILLIS = 100;
|
|
// Adjust this timeout to fail faster for test stalls
|
|
private static final Duration TIMEOUT = Duration.ofNanos((long) (
|
|
Duration.of(1, ChronoUnit.MINUTES).toNanos()
|
|
* Double.parseDouble(System.getProperty("test.timeout.factor", "1.0"))));
|
|
|
|
private static void await(CountDownLatch latch) {
|
|
try {
|
|
if (!latch.await(TIMEOUT.toNanos(), TimeUnit.NANOSECONDS)) {
|
|
fail("No signal received");
|
|
}
|
|
} catch (InterruptedException e) {
|
|
fail(e);
|
|
}
|
|
}
|
|
|
|
private static void awaitThreads(Thread... threads) {
|
|
awaitThreads(Arrays.asList(threads));
|
|
}
|
|
|
|
private static void awaitThreads(Iterable<Thread> threads) {
|
|
for (var t : threads) {
|
|
try {
|
|
if (!t.join(TIMEOUT)) {
|
|
fail("Thread not stopping " + t);
|
|
}
|
|
} catch (InterruptedException ex) {
|
|
fail(ex);
|
|
}
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Tests that get() + remove() can prevent stale value from being installed.
|
|
* Uses junit to do basic stress.
|
|
*/
|
|
@Test
|
|
void testRemoveStale() throws InterruptedException {
|
|
CountDownLatch oldInputUsed = new CountDownLatch(1);
|
|
CountDownLatch inputUpdated = new CountDownLatch(1);
|
|
AtomicInteger input = new AtomicInteger(0);
|
|
ClassValue<Integer> cv = new ClassValue<>() {
|
|
@Override
|
|
protected Integer computeValue(Class<?> type) {
|
|
// must get early to represent using outdated input
|
|
int v = input.get();
|
|
oldInputUsed.countDown();
|
|
// ensure input is updated when we return
|
|
await(inputUpdated);
|
|
return v;
|
|
}
|
|
};
|
|
var innocuous = Thread.startVirtualThread(() -> cv.get(int.class));
|
|
var refreshInput = Thread.startVirtualThread(() -> {
|
|
await(oldInputUsed);
|
|
input.incrementAndGet();
|
|
cv.remove(int.class); // Let's recompute with updated inputs!
|
|
inputUpdated.countDown();
|
|
});
|
|
awaitThreads(innocuous, refreshInput);
|
|
assertEquals(1, input.get(), "input not updated");
|
|
assertEquals(1, cv.get(int.class), "CV not using up-to-date input");
|
|
}
|
|
|
|
/**
|
|
* Tests that calling get() from computeValue() terminates.
|
|
*/
|
|
@Test
|
|
void testGetInCompute() {
|
|
ClassValue<Object> cv = new ClassValue<>() {
|
|
@Override
|
|
protected Object computeValue(Class<?> type) {
|
|
get(type);
|
|
get(type);
|
|
get(type);
|
|
return Boolean.TRUE;
|
|
}
|
|
};
|
|
try {
|
|
cv.get(int.class);
|
|
} catch (Throwable ex) {
|
|
// swallow if any
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Tests that calling remove() from computeValue() terminates.
|
|
*/
|
|
@Test
|
|
void testRemoveInCompute() {
|
|
ClassValue<Boolean> cv = new ClassValue<>() {
|
|
@Override
|
|
protected Boolean computeValue(Class<?> type) {
|
|
remove(type);
|
|
remove(type);
|
|
remove(type);
|
|
return Boolean.TRUE;
|
|
}
|
|
};
|
|
assertTrue(cv.get(int.class));
|
|
}
|
|
|
|
/**
 * Builds an empty class named "WeakTest" with the ClassFile API. Defines it
 * as an initialized hidden class in this test's lookup context. No
 * ClassOption.STRONG is passed, so the class is not strongly held by its
 * defining loader.
 */
private static Class<?> createWeakClass() {
    byte[] classBytes = ClassFile.of().build(ClassDesc.of("WeakTest"), _ -> {});
    MethodHandles.Lookup lookup = MethodHandles.lookup();
    try {
        return lookup.defineHiddenClass(classBytes, true).lookupClass();
    } catch (IllegalAccessException ex) {
        return fail(ex);
    }
}
|
|
|
|
@Test
|
|
void testWeakAgainstClass() {
|
|
Class<?> hidden = createWeakClass();
|
|
ClassValue<int[]> cv = new ClassValue<>() {
|
|
@Override
|
|
protected int[] computeValue(Class<?> type) {
|
|
return new int[23];
|
|
}
|
|
};
|
|
|
|
WeakReference<?> ref = new WeakReference<>(cv.get(hidden));
|
|
hidden = null; // Remove reference for interpreter
|
|
if (!ForceGC.wait(() -> ref.refersTo(null))) {
|
|
fail("Timeout");
|
|
}
|
|
}
|
|
|
|
@Test
|
|
@Disabled // JDK-8352622
|
|
void testWeakAgainstClassValue() {
|
|
ClassValue<int[]> cv = new ClassValue<>() {
|
|
@Override
|
|
protected int[] computeValue(Class<?> type) {
|
|
return new int[23];
|
|
}
|
|
};
|
|
|
|
WeakReference<?> ref = new WeakReference<>(cv.get(int.class));
|
|
cv = null; // Remove reference for interpreter
|
|
if (!ForceGC.wait(() -> {
|
|
CV1.get(int.class); // flush the weak maps
|
|
return ref.refersTo(null);
|
|
})) {
|
|
fail("Timeout");
|
|
}
|
|
}
|
|
|
|
@Test
|
|
void testSingletonWinner() {
|
|
CountDownLatch raceStart = new CountDownLatch(1);
|
|
ClassValue<int[]> cv = new ClassValue<>() {
|
|
@Override
|
|
protected int[] computeValue(Class<?> type) {
|
|
await(raceStart);
|
|
return new int[] {ThreadLocalRandom.current().nextInt()};
|
|
}
|
|
};
|
|
AtomicReference<int[]> truth = new AtomicReference<>(null);
|
|
AtomicInteger truthSwapCount = new AtomicInteger(0);
|
|
|
|
List<Thread> threads = new ArrayList<>(100);
|
|
Runnable job = () -> {
|
|
var res = cv.get(ClassValueTest.class);
|
|
var item = truth.compareAndExchange(null, res);
|
|
if (item != null) {
|
|
assertSame(item, res);
|
|
} else {
|
|
truthSwapCount.incrementAndGet();
|
|
}
|
|
};
|
|
for (int i = 0; i < 100; i++) {
|
|
threads.add(Thread.startVirtualThread(job));
|
|
}
|
|
raceStart.countDown();
|
|
awaitThreads(threads);
|
|
assertEquals(1, truthSwapCount.get());
|
|
}
|
|
|
|
@Test
|
|
void testRacyRemoveInCompute() {
|
|
ClassValue<Object> cv = new ClassValue<>() {
|
|
@Override
|
|
protected Object computeValue(Class<?> type) {
|
|
remove(type);
|
|
try {
|
|
Thread.sleep(COMPUTE_TIME_MILLIS);
|
|
} catch (InterruptedException ex) {
|
|
throw new RuntimeException(ex);
|
|
}
|
|
remove(type);
|
|
return Boolean.TRUE;
|
|
}
|
|
};
|
|
|
|
var threads = Arrays.stream(CLASSES)
|
|
.map(clz -> Thread.startVirtualThread(() -> cv.get(clz)))
|
|
.toList();
|
|
awaitThreads(threads);
|
|
}
|
|
|
|
@Test
|
|
void testRecursiveInitialization() {
|
|
record Holder() {
|
|
static final ClassValue<Object> clv = new ClassValue<>() {
|
|
@Override
|
|
protected Object computeValue(Class<?> type) {
|
|
return new One();
|
|
}
|
|
};
|
|
|
|
record One() {
|
|
static {
|
|
Holder.clv.get(One.class);
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
Holder.clv.get(Holder.One.class);
|
|
}
|
|
|
|
// Bound by the scoped-value based tests to tell computeValue() which test role the calling thread plays.
private static final ScopedValue<Integer> THREAD_ID = ScopedValue.<Integer>newInstance();
|
|
|
|
@Test
|
|
void testNoRecomputeOnUnrelatedRemoval() throws InterruptedException {
|
|
CountDownLatch t1Started = new CountDownLatch(1);
|
|
CountDownLatch removeTamper = new CountDownLatch(1);
|
|
CountDownLatch t2Started = new CountDownLatch(1);
|
|
CountDownLatch t1Returned = new CountDownLatch(1);
|
|
AtomicInteger t1Tries = new AtomicInteger();
|
|
ClassValue<Object> clv = new ClassValue<>() {
|
|
@Override
|
|
protected Object computeValue(Class<?> type) {
|
|
int id = THREAD_ID.get();
|
|
if (id == 1) {
|
|
t1Tries.incrementAndGet();
|
|
t1Started.countDown();
|
|
await(t2Started); // implies unrelated changes
|
|
} else if (id == 2) {
|
|
t2Started.countDown();
|
|
// Don't race to install - we want to check computeValue tampers
|
|
await(t1Returned);
|
|
} else if (id != -1) {
|
|
fail("Id: " + id);
|
|
}
|
|
return "";
|
|
}
|
|
};
|
|
|
|
ScopedValue.where(THREAD_ID, -1).run(() -> clv.get(long.class)); // set up unrelated class
|
|
var t1 = Thread.startVirtualThread(() ->
|
|
ScopedValue.where(THREAD_ID, 1).run(() -> {
|
|
clv.get(int.class);
|
|
t1Returned.countDown(); // returned after x calls to computeValue
|
|
}));
|
|
var t2 = Thread.startVirtualThread(() ->
|
|
ScopedValue.where(THREAD_ID, 2).run(() -> {
|
|
await(removeTamper);
|
|
clv.get(int.class); // clv version diff from that of promise
|
|
}));
|
|
await(t1Started);
|
|
clv.remove(long.class);
|
|
removeTamper.countDown(); // removed unrelated class
|
|
awaitThreads(t1, t2);
|
|
assertEquals(1, t1Tries.get(), "Redundant computeValue retries");
|
|
}
|
|
|
|
@Test
|
|
void testNoObsoleteInstallation() throws InterruptedException {
|
|
CountDownLatch slowComputationStart = new CountDownLatch(1);
|
|
CountDownLatch slowComputationContinue = new CountDownLatch(1);
|
|
ClassValue<Integer> clv = new ClassValue<>() {
|
|
@Override
|
|
protected Integer computeValue(Class<?> type) {
|
|
if (!THREAD_ID.isBound()) {
|
|
return 5;
|
|
}
|
|
int threadId = THREAD_ID.get();
|
|
if (threadId == 1) {
|
|
if (slowComputationContinue.getCount() == 0) {
|
|
return 42;
|
|
} else {
|
|
// First invocation
|
|
slowComputationStart.countDown();
|
|
await(slowComputationContinue);
|
|
return -1;
|
|
}
|
|
} else {
|
|
return fail("Unknown thread " + threadId);
|
|
}
|
|
}
|
|
};
|
|
var t = Thread.startVirtualThread(() -> ScopedValue.where(THREAD_ID, 1).run(() -> {
|
|
int v = clv.get(int.class);
|
|
assertEquals(42, v, "recomputed value");
|
|
}));
|
|
await(slowComputationStart);
|
|
assertEquals(5, clv.get(int.class), "fast computation installed value");
|
|
clv.remove(int.class);
|
|
slowComputationContinue.countDown();
|
|
awaitThreads(t);
|
|
assertEquals(42, clv.get(int.class), "slow computation reinstalled value");
|
|
}
|
|
|
|
// ClassValue cache invalidated and not reinstated when another
|
|
// unrelated entry is removed
|
|
@Test
|
|
public void testCacheRefresh() throws Throwable {
|
|
// Setup
|
|
var lookup = MethodHandles.privateLookupIn(ClassValue.class, MethodHandles.lookup());
|
|
var classValueEntryClass = Class.forName("java.lang.ClassValue$Entry");
|
|
MethodHandle getCacheCarefully = lookup.findStatic(ClassValue.class, "getCacheCarefully",
|
|
MethodType.methodType(classValueEntryClass.arrayType(), Class.class));
|
|
var classValueMapClass = Class.forName("java.lang.ClassValue$ClassValueMap");
|
|
MethodHandle probeHomeLocation = lookup.findStatic(classValueMapClass, "probeHomeLocation",
|
|
MethodType.methodType(classValueEntryClass, classValueEntryClass.arrayType(), ClassValue.class));
|
|
MethodHandle match = lookup.findVirtual(ClassValue.class, "match",
|
|
MethodType.methodType(boolean.class, classValueEntryClass));
|
|
|
|
// Work
|
|
ClassValue<?> clv = new ClassValue<>() {
|
|
@Override
|
|
protected String computeValue(Class<?> type) {
|
|
return "";
|
|
}
|
|
};
|
|
// A class that shouldn't have arbitrary values stuffing the cache
|
|
var cleanClass = clv.getClass();
|
|
clv.get(cleanClass); // create cache on clean class
|
|
assertTrue(checkDirectCacheMatch(
|
|
getCacheCarefully,
|
|
probeHomeLocation,
|
|
match,
|
|
clv,
|
|
cleanClass
|
|
));
|
|
clv.get(int.class);
|
|
clv.remove(int.class); // invalidate cache on clean class
|
|
assertFalse(checkDirectCacheMatch(
|
|
getCacheCarefully,
|
|
probeHomeLocation,
|
|
match,
|
|
clv,
|
|
cleanClass
|
|
));
|
|
clv.get(cleanClass);
|
|
assertTrue(checkDirectCacheMatch(
|
|
getCacheCarefully,
|
|
probeHomeLocation,
|
|
match,
|
|
clv,
|
|
cleanClass
|
|
));
|
|
}
|
|
|
|
private static boolean checkDirectCacheMatch(
|
|
MethodHandle getCacheCarefully,
|
|
MethodHandle probeHomeLocation,
|
|
MethodHandle match,
|
|
ClassValue<?> clv,
|
|
Class<?> cl
|
|
) throws Throwable {
|
|
Object cache = getCacheCarefully.invoke(cl);
|
|
Object entry = probeHomeLocation.invoke(cache, clv);
|
|
return (boolean) match.invoke(clv, entry);
|
|
}
|
|
}
|