diff --git a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java index 6810fa6222d..79dc47e6816 100644 --- a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java +++ b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java @@ -26,17 +26,29 @@ package java.lang.runtime; import java.lang.Enum.EnumDesc; +import java.lang.constant.ClassDesc; +import java.lang.constant.ConstantDescs; +import java.lang.constant.MethodTypeDesc; import java.lang.invoke.CallSite; import java.lang.invoke.ConstantCallSite; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; +import java.lang.reflect.AccessFlag; +import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.Optional; +import java.util.function.BiPredicate; import java.util.stream.Stream; import jdk.internal.access.SharedSecrets; +import jdk.internal.classfile.Classfile; +import jdk.internal.classfile.Label; +import jdk.internal.classfile.instruction.SwitchCase; import jdk.internal.vm.annotation.Stable; +import static java.lang.invoke.MethodHandles.Lookup.ClassOption.NESTMATE; +import static java.lang.invoke.MethodHandles.Lookup.ClassOption.STRONG; import static java.util.Objects.requireNonNull; /** @@ -54,26 +66,16 @@ public class SwitchBootstraps { private static final Object SENTINEL = new Object(); private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); - private static final MethodHandle INSTANCEOF_CHECK; - private static final MethodHandle INTEGER_EQ_CHECK; - private static final MethodHandle OBJECT_EQ_CHECK; - private static final MethodHandle ENUM_EQ_CHECK; private static final MethodHandle NULL_CHECK; private static final MethodHandle IS_ZERO; private static final MethodHandle CHECK_INDEX; private static final MethodHandle MAPPED_ENUM_LOOKUP; + private static final MethodTypeDesc TYPES_SWITCH_DESCRIPTOR = + MethodTypeDesc.ofDescriptor("(Ljava/lang/Object;ILjava/util/function/BiPredicate;Ljava/util/List;)I"); + static { try { - INSTANCEOF_CHECK = MethodHandles.permuteArguments(LOOKUP.findVirtual(Class.class, "isInstance", - MethodType.methodType(boolean.class, Object.class)), - MethodType.methodType(boolean.class, Object.class, Class.class), 1, 0); - INTEGER_EQ_CHECK = LOOKUP.findStatic(SwitchBootstraps.class, "integerEqCheck", - MethodType.methodType(boolean.class, Object.class, Integer.class)); - OBJECT_EQ_CHECK = LOOKUP.findStatic(Objects.class, "equals", - MethodType.methodType(boolean.class, Object.class, Object.class)); - ENUM_EQ_CHECK = LOOKUP.findStatic(SwitchBootstraps.class, "enumEqCheck", - MethodType.methodType(boolean.class, Object.class, EnumDesc.class, MethodHandles.Lookup.class, ResolvedEnumLabel.class)); NULL_CHECK = LOOKUP.findStatic(Objects.class, "isNull", MethodType.methodType(boolean.class, Object.class)); IS_ZERO = LOOKUP.findStatic(SwitchBootstraps.class, "isZero", @@ -155,7 +157,9 @@ public class SwitchBootstraps { labels = labels.clone(); Stream.of(labels).forEach(SwitchBootstraps::verifyLabel); - MethodHandle target = createMethodHandleSwitch(lookup, labels); + MethodHandle target = generateInnerClass(lookup, labels); + + target = withIndexCheck(target, labels.length); return new ConstantCallSite(target); } @@ -173,79 +177,6 @@ public class SwitchBootstraps { } } - /* - * Construct test chains for labels inside switch, to handle switch repeats: - * switch (idx) { - * case 0 -> if (selector matches label[0]) return 0; else if (selector matches label[1]) return 1; else ... - * case 1 -> if (selector matches label[1]) return 1; else ... - * ... - * } - */ - private static MethodHandle createRepeatIndexSwitch(MethodHandles.Lookup lookup, Object[] labels) { - MethodHandle def = MethodHandles.dropArguments(MethodHandles.constant(int.class, labels.length), 0, Object.class); - MethodHandle[] testChains = new MethodHandle[labels.length]; - List labelsList = List.of(labels).reversed(); - - for (int i = 0; i < labels.length; i++) { - MethodHandle test = def; - int idx = labels.length - 1; - List currentLabels = labelsList.subList(0, labels.length - i); - - for (int j = 0; j < currentLabels.size(); j++, idx--) { - Object currentLabel = currentLabels.get(j); - if (j + 1 < currentLabels.size() && currentLabels.get(j + 1) == currentLabel) continue; - MethodHandle currentTest; - if (currentLabel instanceof Class) { - currentTest = INSTANCEOF_CHECK; - } else if (currentLabel instanceof Integer) { - currentTest = INTEGER_EQ_CHECK; - } else if (currentLabel instanceof EnumDesc) { - currentTest = MethodHandles.insertArguments(ENUM_EQ_CHECK, 2, lookup, new ResolvedEnumLabel()); - } else { - currentTest = OBJECT_EQ_CHECK; - } - test = MethodHandles.guardWithTest(MethodHandles.insertArguments(currentTest, 1, currentLabel), - MethodHandles.dropArguments(MethodHandles.constant(int.class, idx), 0, Object.class), - test); - } - testChains[i] = MethodHandles.dropArguments(test, 0, int.class); - } - - return MethodHandles.tableSwitch(MethodHandles.dropArguments(def, 0, int.class), testChains); - } - - /* - * Construct code that maps the given selector and repeat index to a case label number: - * if (selector == null) return -1; - * else return "createRepeatIndexSwitch(labels)" - */ - private static MethodHandle createMethodHandleSwitch(MethodHandles.Lookup lookup, Object[] labels) { - MethodHandle mainTest; - MethodHandle def = MethodHandles.dropArguments(MethodHandles.constant(int.class, labels.length), 0, Object.class); - if (labels.length > 0) { - mainTest = createRepeatIndexSwitch(lookup, labels); - } else { - mainTest = MethodHandles.dropArguments(def, 0, int.class); - } - MethodHandle body = - MethodHandles.guardWithTest(MethodHandles.dropArguments(NULL_CHECK, 0, int.class), - MethodHandles.dropArguments(MethodHandles.constant(int.class, -1), 0, int.class, Object.class), - mainTest); - MethodHandle switchImpl = - MethodHandles.permuteArguments(body, MethodType.methodType(int.class, Object.class, int.class), 1, 0); - return withIndexCheck(switchImpl, labels.length); - } - - private static boolean integerEqCheck(Object value, Integer constant) { - if (value instanceof Number input && constant.intValue() == input.intValue()) { - return true; - } else if (value instanceof Character input && constant.intValue() == input.charValue()) { - return true; - } - - return false; - } - private static boolean isZero(int value) { return value == 0; } @@ -330,16 +261,16 @@ public class SwitchBootstraps { //If all labels are enum constants, construct an optimized handle for repeat index 0: //if (selector == null) return -1 //else if (idx == 0) return mappingArray[selector.ordinal()]; //mapping array created lazily - //else return "createRepeatIndexSwitch(labels)" + //else return "typeSwitch(labels)" MethodHandle body = MethodHandles.guardWithTest(MethodHandles.dropArguments(NULL_CHECK, 0, int.class), MethodHandles.dropArguments(MethodHandles.constant(int.class, -1), 0, int.class, Object.class), MethodHandles.guardWithTest(MethodHandles.dropArguments(IS_ZERO, 1, Object.class), - createRepeatIndexSwitch(lookup, labels), + generateInnerClass(lookup, labels), MethodHandles.insertArguments(MAPPED_ENUM_LOOKUP, 1, lookup, enumClass, labels, new EnumMap()))); target = MethodHandles.permuteArguments(body, MethodType.methodType(int.class, Object.class, int.class), 1, 0); } else { - target = createMethodHandleSwitch(lookup, labels); + target = generateInnerClass(lookup, labels); } target = target.asType(invocationType); @@ -360,7 +291,7 @@ public class SwitchBootstraps { } return label; } else if (labelClass == String.class) { - return EnumDesc.of(enumClassTemplate.describeConstable().get(), (String) label); + return EnumDesc.of(enumClassTemplate.describeConstable().orElseThrow(), (String) label); } else { throw new IllegalArgumentException("label with illegal type found: " + labelClass + ", expected label of type either String or Class"); @@ -389,45 +320,225 @@ public class SwitchBootstraps { return enumMap.map[value.ordinal()]; } - private static boolean enumEqCheck(Object value, EnumDesc label, MethodHandles.Lookup lookup, ResolvedEnumLabel resolvedEnum) { - if (resolvedEnum.resolvedEnum == null) { - Object resolved; - - try { - if (!(value instanceof Enum enumValue)) { - return false; - } - - Class clazz = label.constantType().resolveConstantDesc(lookup); - - if (enumValue.getDeclaringClass() != clazz) { - return false; - } - - resolved = label.resolveConstantDesc(lookup); - } catch (IllegalArgumentException | ReflectiveOperationException ex) { - resolved = SENTINEL; - } - - resolvedEnum.resolvedEnum = resolved; - } - - return value == resolvedEnum.resolvedEnum; - } - private static MethodHandle withIndexCheck(MethodHandle target, int labelsCount) { MethodHandle checkIndex = MethodHandles.insertArguments(CHECK_INDEX, 1, labelsCount + 1); return MethodHandles.filterArguments(target, 1, checkIndex); } - private static final class ResolvedEnumLabel { + private static final class ResolvedEnumLabels implements BiPredicate { + + private final MethodHandles.Lookup lookup; + private final EnumDesc[] enumDescs; @Stable - public Object resolvedEnum; + private Object[] resolvedEnum; + + public ResolvedEnumLabels(MethodHandles.Lookup lookup, EnumDesc[] enumDescs) { + this.lookup = lookup; + this.enumDescs = enumDescs; + this.resolvedEnum = new Object[enumDescs.length]; + } + + @Override + public boolean test(Integer labelIndex, Object value) { + Object result = resolvedEnum[labelIndex]; + + if (result == null) { + try { + if (!(value instanceof Enum enumValue)) { + return false; + } + + EnumDesc label = enumDescs[labelIndex]; + Class clazz = label.constantType().resolveConstantDesc(lookup); + + if (enumValue.getDeclaringClass() != clazz) { + return false; + } + + result = label.resolveConstantDesc(lookup); + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + result = SENTINEL; + } + + resolvedEnum[labelIndex] = result; + } + + return result == value; + } } private static final class EnumMap { @Stable public int[] map; } + + /* + * Construct test chains for labels inside switch, to handle switch repeats: + * switch (idx) { + * case 0 -> if (selector matches label[0]) return 0; + * case 1 -> if (selector matches label[1]) return 1; + * ... + * } + */ + @SuppressWarnings("removal") + private static MethodHandle generateInnerClass(MethodHandles.Lookup caller, Object[] labels) { + List> enumDescs = new ArrayList<>(); + List> extraClassLabels = new ArrayList<>(); + + byte[] classBytes = Classfile.of().build(ClassDesc.of(typeSwitchClassName(caller.lookupClass())), clb -> { + clb.withFlags(AccessFlag.FINAL, AccessFlag.SUPER, AccessFlag.SYNTHETIC) + .withMethodBody("typeSwitch", + TYPES_SWITCH_DESCRIPTOR, + Classfile.ACC_FINAL | Classfile.ACC_PUBLIC | Classfile.ACC_STATIC, + cb -> { + cb.aload(0); + Label nonNullLabel = cb.newLabel(); + cb.if_nonnull(nonNullLabel); + cb.iconst_m1(); + cb.ireturn(); + cb.labelBinding(nonNullLabel); + if (labels.length == 0) { + cb.constantInstruction(0) + .ireturn(); + return ; + } + cb.iload(1); + Label dflt = cb.newLabel(); + record Element(Label target, Label next, Object caseLabel) {} + List cases = new ArrayList<>(); + List switchCases = new ArrayList<>(); + Object lastLabel = null; + for (int idx = labels.length - 1; idx >= 0; idx--) { + Object currentLabel = labels[idx]; + Label target = cb.newLabel(); + Label next; + if (lastLabel == null) { + next = dflt; + } else if (lastLabel.equals(currentLabel)) { + next = cases.getLast().next(); + } else { + next = cases.getLast().target(); + } + lastLabel = currentLabel; + cases.add(new Element(target, next, currentLabel)); + switchCases.add(SwitchCase.of(idx, target)); + } + cases = cases.reversed(); + switchCases = switchCases.reversed(); + cb.tableswitch(0, labels.length - 1, dflt, switchCases); + for (int idx = 0; idx < cases.size(); idx++) { + Element element = cases.get(idx); + Label next = element.next(); + cb.labelBinding(element.target()); + if (element.caseLabel() instanceof Class classLabel) { + Optional classLabelConstableOpt = classLabel.describeConstable(); + if (classLabelConstableOpt.isPresent()) { + cb.aload(0); + cb.instanceof_(classLabelConstableOpt.orElseThrow()); + cb.ifeq(next); + } else { + cb.aload(3); + cb.constantInstruction(extraClassLabels.size()); + cb.invokeinterface(ConstantDescs.CD_List, + "get", + MethodTypeDesc.of(ConstantDescs.CD_Object, + ConstantDescs.CD_int)); + cb.checkcast(ConstantDescs.CD_Class); + cb.aload(0); + cb.invokevirtual(ConstantDescs.CD_Class, + "isInstance", + MethodTypeDesc.of(ConstantDescs.CD_boolean, + ConstantDescs.CD_Object)); + cb.ifeq(next); + extraClassLabels.add(classLabel); + } + } else if (element.caseLabel() instanceof EnumDesc enumLabel) { + int enumIdx = enumDescs.size(); + enumDescs.add(enumLabel); + cb.aload(2); + cb.constantInstruction(enumIdx); + cb.invokestatic(ConstantDescs.CD_Integer, + "valueOf", + MethodTypeDesc.of(ConstantDescs.CD_Integer, + ConstantDescs.CD_int)); + cb.aload(0); + cb.invokeinterface(BiPredicate.class.describeConstable().orElseThrow(), + "test", + MethodTypeDesc.of(ConstantDescs.CD_boolean, + ConstantDescs.CD_Object, + ConstantDescs.CD_Object)); + cb.ifeq(next); + } else if (element.caseLabel() instanceof String stringLabel) { + cb.ldc(stringLabel); + cb.aload(0); + cb.invokevirtual(ConstantDescs.CD_Object, + "equals", + MethodTypeDesc.of(ConstantDescs.CD_boolean, + ConstantDescs.CD_Object)); + cb.ifeq(next); + } else if (element.caseLabel() instanceof Integer integerLabel) { + Label compare = cb.newLabel(); + Label notNumber = cb.newLabel(); + cb.aload(0); + cb.instanceof_(ConstantDescs.CD_Number); + cb.ifeq(notNumber); + cb.aload(0); + cb.checkcast(ConstantDescs.CD_Number); + cb.invokevirtual(ConstantDescs.CD_Number, + "intValue", + MethodTypeDesc.of(ConstantDescs.CD_int)); + cb.goto_(compare); + cb.labelBinding(notNumber); + cb.aload(0); + cb.instanceof_(ConstantDescs.CD_Character); + cb.ifeq(next); + cb.aload(0); + cb.checkcast(ConstantDescs.CD_Character); + cb.invokevirtual(ConstantDescs.CD_Character, + "charValue", + MethodTypeDesc.of(ConstantDescs.CD_char)); + cb.labelBinding(compare); + cb.ldc(integerLabel); + cb.if_icmpne(next); + } else { + throw new InternalError("Unsupported label type: " + + element.caseLabel().getClass()); + } + cb.constantInstruction(idx); + cb.ireturn(); + } + cb.labelBinding(dflt); + cb.constantInstruction(cases.size()); + cb.ireturn(); + }); + }); + + try { + // this class is linked at the indy callsite; so define a hidden nestmate + MethodHandles.Lookup lookup; + lookup = caller.defineHiddenClass(classBytes, true, NESTMATE, STRONG); + MethodHandle typeSwitch = lookup.findStatic(lookup.lookupClass(), + "typeSwitch", + MethodType.methodType(int.class, + Object.class, + int.class, + BiPredicate.class, + List.class)); + return MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(EnumDesc[]::new)), + List.copyOf(extraClassLabels)); + } catch (Throwable t) { + throw new IllegalArgumentException(t); + } + } + + //based on src/java.base/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java: + private static String typeSwitchClassName(Class targetClass) { + String name = targetClass.getName(); + if (targetClass.isHidden()) { + // use the original class name + name = name.replace('/', '_'); + } + return name + "$$TypeSwitch"; + } } diff --git a/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java b/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java index 1489a6d55b7..8e31f35bef5 100644 --- a/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java +++ b/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java @@ -24,12 +24,16 @@ import java.io.Serializable; import java.lang.Enum.EnumDesc; import java.lang.constant.ClassDesc; +import java.lang.constant.ConstantDescs; +import java.lang.constant.MethodTypeDesc; import java.lang.invoke.CallSite; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; +import java.lang.reflect.AccessFlag; import java.lang.runtime.SwitchBootstraps; import java.util.concurrent.atomic.AtomicBoolean; +import jdk.internal.classfile.Classfile; import org.testng.annotations.Test; @@ -42,6 +46,7 @@ import static org.testng.Assert.fail; * @test * @bug 8318144 * @enablePreview + * @modules java.base/jdk.internal.classfile * @compile SwitchBootstrapsTest.java * @run testng/othervm SwitchBootstrapsTest */ @@ -113,9 +118,12 @@ public class SwitchBootstrapsTest { } catch (IllegalArgumentException ex) { //OK } - testType("", 0, 0, String.class, String.class, String.class); - testType("", 1, 1, String.class, String.class, String.class); - testType("", 2, 2, String.class, String.class, String.class); + testType("", 0, 0, String.class, String.class, String.class, String.class, String.class); + testType("", 1, 1, String.class, String.class, String.class, String.class, String.class); + testType("", 2, 2, String.class, String.class, String.class, String.class, String.class); + testType("", 3, 3, String.class, String.class, String.class, String.class, String.class); + testType("", 3, 3, String.class, String.class, String.class, String.class, String.class); + testType("", 4, 4, String.class, String.class, String.class, String.class, String.class); testType("", 0, 0); } @@ -346,4 +354,32 @@ public class SwitchBootstrapsTest { } } + public void testHiddenClassAsCaseLabel() throws Throwable { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + byte[] classBytes = createClass(); + Class classA = lookup.defineHiddenClass(classBytes, false).lookupClass(); + Class classB = lookup.defineHiddenClass(classBytes, false).lookupClass(); + Object[] labels = new Object[] { + classA, + classB, + }; + testType(classA.getConstructor().newInstance(), 0, 0, labels); + testType(classB.getConstructor().newInstance(), 0, 1, labels); + } + + private static byte[] createClass() { + return Classfile.of().build(ClassDesc.of("C"), clb -> { + clb.withFlags(AccessFlag.SYNTHETIC) + .withMethodBody("", + MethodTypeDesc.of(ConstantDescs.CD_void), + Classfile.ACC_PUBLIC, + cb -> { + cb.aload(0); + cb.invokespecial(ConstantDescs.CD_Object, + "", + MethodTypeDesc.of(ConstantDescs.CD_void)); + cb.return_(); + }); + }); + } }