From dc55a7fc877ab5ea4efbed90454194008143aeb4 Mon Sep 17 00:00:00 2001 From: Jan Lahoda Date: Fri, 17 Feb 2023 12:55:47 +0000 Subject: [PATCH] 8302202: Incorrect desugaring of null-allowed nested patterns Reviewed-by: vromero --- .../sun/tools/javac/comp/TransPatterns.java | 25 ++- .../NullsInDeconstructionPatterns2.java | 163 ++++++++++++++++++ 2 files changed, 180 insertions(+), 8 deletions(-) create mode 100644 test/langtools/tools/javac/patterns/NullsInDeconstructionPatterns2.java diff --git a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java index 3463439a8a2..d6070c86ec8 100644 --- a/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java +++ b/src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java @@ -232,8 +232,8 @@ public class TransPatterns extends TreeTranslator { } Type principalType = types.erasure(TreeInfo.primaryPatternType((pattern))); - JCExpression resultExpression= (JCExpression) this.translate(pattern); - if (!tree.allowNull || !types.isSubtype(currentValue.type, principalType)) { + JCExpression resultExpression = (JCExpression) this.translate(pattern); + if (!tree.allowNull && !principalType.isPrimitive()) { resultExpression = makeBinary(Tag.AND, makeTypeTest(make.Ident(currentValue), make.Type(principalType)), @@ -330,7 +330,8 @@ public class TransPatterns extends TreeTranslator { allowNull = false; } else { nestedBinding = (JCBindingPattern) nestedPattern; - allowNull = true; + allowNull = types.isSubtype(componentType, + types.boxedTypeOrType(types.erasure(nestedBinding.type))); } JCMethodInvocation componentAccessor = make.App(make.Select(convert(make.Ident(recordBinding), recordBinding.type), //TODO - cast needed???? @@ -710,6 +711,7 @@ public class TransPatterns extends TreeTranslator { "commonNestedExpression: " + commonNestedExpression + "commonNestedBinding: " + commonNestedBinding); ListBuffer nestedCases = new ListBuffer<>(); + JCExpression lastGuard = null; for(List accList = accummulator.toList(); accList.nonEmpty(); accList = accList.tail) { var accummulated = accList.head; @@ -734,8 +736,6 @@ public class TransPatterns extends TreeTranslator { JCBindingPattern binding = (JCBindingPattern) instanceofCheck.pattern; hasUnconditional = instanceofCheck.allowNull && - types.isSubtype(commonNestedExpression.type, - types.boxedTypeOrType(types.erasure(binding.type))) && accList.tail.isEmpty(); List newLabel; if (hasUnconditional) { @@ -746,13 +746,16 @@ public class TransPatterns extends TreeTranslator { } appendBreakIfNeeded(currentSwitch, accummulated); nestedCases.add(make.Case(CaseKind.STATEMENT, newLabel, accummulated.stats, null)); + lastGuard = newGuard; } - if (!hasUnconditional) { + if (lastGuard != null || !hasUnconditional) { JCContinue continueSwitch = make.Continue(null); continueSwitch.target = currentSwitch; nestedCases.add(make.Case(CaseKind.STATEMENT, - List.of(make.ConstantCaseLabel(makeNull()), - make.DefaultCaseLabel()), + hasUnconditional + ? List.of(make.DefaultCaseLabel()) + : List.of(make.ConstantCaseLabel(makeNull()), + make.DefaultCaseLabel()), List.of(continueSwitch), null)); } @@ -774,9 +777,11 @@ public class TransPatterns extends TreeTranslator { VarSymbol commonBinding = null; JCExpression commonNestedExpression = null; VarSymbol commonNestedBinding = null; + boolean previousNullable = false; for (List c = inputCases; c.nonEmpty(); c = c.tail) { VarSymbol currentBinding = null; + boolean currentNullable = false; JCExpression currentNestedExpression = null; VarSymbol currentNestedBinding = null; @@ -786,11 +791,13 @@ public class TransPatterns extends TreeTranslator { binOp.lhs instanceof JCInstanceOf instanceofCheck && instanceofCheck.pattern instanceof JCBindingPattern binding) { currentBinding = ((JCBindingPattern) patternLabel.pat).var.sym; + currentNullable = instanceofCheck.allowNull; currentNestedExpression = instanceofCheck.expr; currentNestedBinding = binding.var.sym; } else if (patternLabel.guard instanceof JCInstanceOf instanceofCheck && instanceofCheck.pattern instanceof JCBindingPattern binding) { currentBinding = ((JCBindingPattern) patternLabel.pat).var.sym; + currentNullable = instanceofCheck.allowNull; currentNestedExpression = instanceofCheck.expr; currentNestedBinding = binding.var.sym; } @@ -806,6 +813,7 @@ public class TransPatterns extends TreeTranslator { } } else if (currentBinding != null && commonBinding.type.tsym == currentBinding.type.tsym && + !previousNullable && new TreeDiffer(List.of(commonBinding), List.of(currentBinding)) .scan(commonNestedExpression, currentNestedExpression)) { accummulator.add(c.head); @@ -820,6 +828,7 @@ public class TransPatterns extends TreeTranslator { commonNestedExpression = currentNestedExpression; commonNestedBinding = currentNestedBinding; } + previousNullable = currentNullable; } resolveAccummulator.resolve(commonBinding, commonNestedExpression, commonNestedBinding); return result.toList(); diff --git a/test/langtools/tools/javac/patterns/NullsInDeconstructionPatterns2.java b/test/langtools/tools/javac/patterns/NullsInDeconstructionPatterns2.java new file mode 100644 index 00000000000..1ba2d04c0b9 --- /dev/null +++ b/test/langtools/tools/javac/patterns/NullsInDeconstructionPatterns2.java @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8302202 + * @summary Testing record patterns with null components + * @enablePreview + * @compile NullsInDeconstructionPatterns2.java + * @run main NullsInDeconstructionPatterns2 + */ + +import java.util.Objects; +import java.util.function.Function; + +public class NullsInDeconstructionPatterns2 { + + public static void main(String[] args) { + new NullsInDeconstructionPatterns2().run(); + } + + private void run() { + run1(this::test1a); + run1(this::test1b); + run2(this::test2a); + run2(this::test2b); + run3(this::test3a); + run3(this::test3b); + run4(); + } + + private void run1(Function method) { + assertEquals("R1(null)", method.apply(new R1(null))); + assertEquals("R1(!null)", method.apply(new R1(""))); + } + + private void run2(Function method) { + assertEquals("R2(null, null)", method.apply(new R2(null, null))); + assertEquals("R2(!null, null)", method.apply(new R2("", null))); + assertEquals("R2(null, !null)", method.apply(new R2(null, ""))); + assertEquals("R2(!null, !null)", method.apply(new R2("", ""))); + } + + private void run3(Function method) { + assertEquals("R3(null, null, null)", method.apply(new R3(null, null, null))); + assertEquals("R3(!null, null, null)", method.apply(new R3("", null, null))); + assertEquals("R3(null, !null, null)", method.apply(new R3(null, "", null))); + assertEquals("R3(!null, !null, null)", method.apply(new R3("", "", null))); + assertEquals("R3(null, null, !null)", method.apply(new R3(null, null, ""))); + assertEquals("R3(!null, null, !null)", method.apply(new R3("", null, ""))); + assertEquals("R3(null, !null, !null)", method.apply(new R3(null, "", ""))); + assertEquals("R3(!null, !null, !null)", method.apply(new R3("", "", ""))); + } + + private void run4() { + assertEquals("integer", test4(new R1(0))); + assertEquals("empty", test4(new R1(""))); + assertEquals("default", test4(new R1("a"))); + } + private String test1a(Object i) { + return switch (i) { + case R1(Object o) when o == null -> "R1(null)"; + case R1(Object o) when o != null -> "R1(!null)"; + default -> "default"; + }; + } + + private String test1b(Object i) { + return switch (i) { + case R1(Object o) when o == null -> "R1(null)"; + case R1(Object o) -> "R1(!null)"; + default -> "default"; + }; + } + + private String test2a(Object i) { + return switch (i) { + case R2(Object o1, Object o2) when o1 == null && o2 == null -> "R2(null, null)"; + case R2(Object o1, Object o2) when o1 != null && o2 == null -> "R2(!null, null)"; + case R2(Object o1, Object o2) when o1 == null && o2 != null -> "R2(null, !null)"; + case R2(Object o1, Object o2) when o1 != null && o2 != null -> "R2(!null, !null)"; + default -> "default"; + }; + } + + private String test2b(Object i) { + return switch (i) { + case R2(Object o1, Object o2) when o1 == null && o2 == null -> "R2(null, null)"; + case R2(Object o1, Object o2) when o1 != null && o2 == null -> "R2(!null, null)"; + case R2(Object o1, Object o2) when o1 == null && o2 != null -> "R2(null, !null)"; + case R2(Object o1, Object o2) -> "R2(!null, !null)"; + default -> "default"; + }; + } + + private String test3a(Object i) { + return switch (i) { + case R3(Object o1, Object o2, Object o3) when o1 == null && o2 == null && o3 == null -> "R3(null, null, null)"; + case R3(Object o1, Object o2, Object o3) when o1 != null && o2 == null && o3 == null -> "R3(!null, null, null)"; + case R3(Object o1, Object o2, Object o3) when o1 == null && o2 != null && o3 == null -> "R3(null, !null, null)"; + case R3(Object o1, Object o2, Object o3) when o1 != null && o2 != null && o3 == null -> "R3(!null, !null, null)"; + case R3(Object o1, Object o2, Object o3) when o1 == null && o2 == null && o3 != null -> "R3(null, null, !null)"; + case R3(Object o1, Object o2, Object o3) when o1 != null && o2 == null && o3 != null -> "R3(!null, null, !null)"; + case R3(Object o1, Object o2, Object o3) when o1 == null && o2 != null && o3 != null -> "R3(null, !null, !null)"; + case R3(Object o1, Object o2, Object o3) when o1 != null && o2 != null && o3 != null -> "R3(!null, !null, !null)"; + default -> "default"; + }; + } + + private String test3b(Object i) { + return switch (i) { + case R3(Object o1, Object o2, Object o3) when o1 == null && o2 == null && o3 == null -> "R3(null, null, null)"; + case R3(Object o1, Object o2, Object o3) when o1 != null && o2 == null && o3 == null -> "R3(!null, null, null)"; + case R3(Object o1, Object o2, Object o3) when o1 == null && o2 != null && o3 == null -> "R3(null, !null, null)"; + case R3(Object o1, Object o2, Object o3) when o1 != null && o2 != null && o3 == null -> "R3(!null, !null, null)"; + case R3(Object o1, Object o2, Object o3) when o1 == null && o2 == null && o3 != null -> "R3(null, null, !null)"; + case R3(Object o1, Object o2, Object o3) when o1 != null && o2 == null && o3 != null -> "R3(!null, null, !null)"; + case R3(Object o1, Object o2, Object o3) when o1 == null && o2 != null && o3 != null -> "R3(null, !null, !null)"; + case R3(Object o1, Object o2, Object o3) -> "R3(!null, !null, !null)"; + default -> "default"; + }; + } + + private String test4(Object i) { + return switch (i) { + case R1(Integer o) -> "integer"; + case R1(Object o) when o.toString().isEmpty() -> "empty"; + default -> "default"; + }; + } + + private static void assertEquals(String expected, String actual) { + if (!Objects.equals(expected, actual)) { + throw new AssertionError("Unexpected result, expected: " + expected + "," + + " actual: " + actual); + } + } + + record R1(Object o) {} + record R2(Object o1, Object o2) {} + record R3(Object o1, Object o2, Object o3) {} + +}