From ca405d0eb2a0ed63dc169aceb80512bf2a523da1 Mon Sep 17 00:00:00 2001 From: Roland Westrelin Date: Thu, 30 Apr 2026 07:44:14 +0000 Subject: [PATCH] 8376400: C2: folding ifs may cause incorrect execution when trap is taken Reviewed-by: chagedorn, aseoane --- src/hotspot/share/opto/callnode.hpp | 15 +- src/hotspot/share/opto/cfgnode.hpp | 1 + src/hotspot/share/opto/ifnode.cpp | 55 ++++ src/hotspot/share/opto/split_if.cpp | 2 +- .../rangechecks/TestFoldedIfsWrongReexec.java | 299 ++++++++++++++++++ 5 files changed, 369 insertions(+), 3 deletions(-) create mode 100644 test/hotspot/jtreg/compiler/rangechecks/TestFoldedIfsWrongReexec.java diff --git a/src/hotspot/share/opto/callnode.hpp b/src/hotspot/share/opto/callnode.hpp index e4c548fc744..6bf61730ca0 100644 --- a/src/hotspot/share/opto/callnode.hpp +++ b/src/hotspot/share/opto/callnode.hpp @@ -879,11 +879,14 @@ public: // calls and optimized virtual calls, plus calls to wrappers for run-time // routines); generates static stub. class CallStaticJavaNode : public CallJavaNode { + // If this is an uncommon trap guarded by some condition, is it safe to change the condition to a narrower condition? + // See comment in PhaseIdealLoop::do_split_if() + bool _safe_for_fold_compare; virtual bool cmp( const Node &n ) const; virtual uint size_of() const; // Size is bigger public: CallStaticJavaNode(Compile* C, const TypeFunc* tf, address addr, ciMethod* method) - : CallJavaNode(tf, addr, method) { + : CallJavaNode(tf, addr, method), _safe_for_fold_compare(true) { init_class_id(Class_CallStaticJava); if (C->eliminate_boxing() && (method != nullptr) && method->is_boxing_method()) { init_flags(Flag_is_macro); @@ -891,7 +894,7 @@ public: } } CallStaticJavaNode(const TypeFunc* tf, address addr, const char* name, const TypePtr* adr_type) - : CallJavaNode(tf, addr, nullptr) { + : CallJavaNode(tf, addr, nullptr), _safe_for_fold_compare(true) { init_class_id(Class_CallStaticJava); // This node calls a runtime stub, which often has narrow memory effects. _adr_type = adr_type; @@ -915,6 +918,14 @@ public: virtual int Opcode() const; virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); + void clear_safe_for_fold_compare() { + _safe_for_fold_compare = false; + } + + bool safe_for_fold_compare() const { + return _safe_for_fold_compare; + } + #ifndef PRODUCT virtual void dump_spec(outputStream *st) const; virtual void dump_compact_spec(outputStream *st) const; diff --git a/src/hotspot/share/opto/cfgnode.hpp b/src/hotspot/share/opto/cfgnode.hpp index 6af2972e688..5f5e68255db 100644 --- a/src/hotspot/share/opto/cfgnode.hpp +++ b/src/hotspot/share/opto/cfgnode.hpp @@ -477,6 +477,7 @@ public: #endif bool same_condition(const Node* dom, PhaseIterGVN* igvn) const; + void mark_projections_unsafe_for_fold_compare() const; }; class RangeCheckNode : public IfNode { diff --git a/src/hotspot/share/opto/ifnode.cpp b/src/hotspot/share/opto/ifnode.cpp index ad8f0ced6ea..9f99874d76a 100644 --- a/src/hotspot/share/opto/ifnode.cpp +++ b/src/hotspot/share/opto/ifnode.cpp @@ -879,6 +879,10 @@ bool IfNode::has_only_uncommon_traps(IfProjNode* proj, IfProjNode*& success, IfP return false; } + if (!dom_unc->safe_for_fold_compare()) { + return false; + } + // See merge_uncommon_traps: the reason of the uncommon trap // will be changed and the state of the dominating If will be // used. Checked that we didn't apply this transformation in a @@ -1671,6 +1675,57 @@ bool IfNode::same_condition(const Node* dom, PhaseIterGVN* igvn) const { return true; } +void IfNode::mark_projections_unsafe_for_fold_compare() const { + // With the following code pattern + // + // if (some_condition) { + // v = 0; + // } else { + // v = 1; + // } // v is Phi(0, 1) + // if (v == 0) { + // uncommon_trap(); // reexecutes the "if (v == 0) {" above, captures v as stack argument to ifeq bytecode + // } + // if (some_other_condition) { + // uncommon_trap(); // reexecutes the "if (some_other_condition) {" + // } + // + // if the second if is split thru Phi, the result is: + // + // if (some_condition) { + // uncommon_trap(); // reexecutes the "if (v == 0) {" that was removed above, captures v = 0 as stack argument to ifeq bytecode + // } + // if (some_other_condition) { + // uncommon_trap(); // reexecutes the "if (some_other_condition) {" + // } + // + // some_condition and some_other_condition could be folded into + // a single new condition that is narrower than some_condition + // (done by IfNode::fold_compares(), for instance): + // + // if (combined_narrower_condition) { + // uncommon_trap(); // reexecutes the "if (v == 0) {" that was removed, captures v = 0 as stack argument to ifeq bytecode + // } + // + // Then combined_narrower_condition is true for some input value for + // which some_condition is false. When such an input value is used + // at runtime, the trap is taken which causes "if (v == 0) {" to be + // reexecuted with v = 0 even though some_condition is wrong, causing + // the wrong branch to be executed. + // + // Mark the uncommon trap nodes to prevent such a transformation + // from happening. + IfProjNode* true_projection = true_proj(); + IfProjNode* false_projection = false_proj(); + CallStaticJavaNode* unc = true_projection->is_uncommon_trap_proj(); + if (unc != nullptr) { + unc->clear_safe_for_fold_compare(); + } + unc = false_projection->is_uncommon_trap_proj(); + if (unc != nullptr) { + unc->clear_safe_for_fold_compare(); + } +} static int subsuming_bool_test_encode(Node*); diff --git a/src/hotspot/share/opto/split_if.cpp b/src/hotspot/share/opto/split_if.cpp index e5f8043ae19..c8f25f92de3 100644 --- a/src/hotspot/share/opto/split_if.cpp +++ b/src/hotspot/share/opto/split_if.cpp @@ -615,7 +615,7 @@ void PhaseIdealLoop::handle_use( Node *use, Node *def, small_cache *cache, Node // Found an If getting its condition-code input from a Phi in the same block. // Split thru the Region. void PhaseIdealLoop::do_split_if(Node* iff, RegionNode** new_false_region, RegionNode** new_true_region) { - + iff->as_If()->mark_projections_unsafe_for_fold_compare(); C->set_major_progress(); RegionNode *region = iff->in(0)->as_Region(); Node *region_dom = idom(region); diff --git a/test/hotspot/jtreg/compiler/rangechecks/TestFoldedIfsWrongReexec.java b/test/hotspot/jtreg/compiler/rangechecks/TestFoldedIfsWrongReexec.java new file mode 100644 index 00000000000..9e77ceca0dd --- /dev/null +++ b/test/hotspot/jtreg/compiler/rangechecks/TestFoldedIfsWrongReexec.java @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2026 IBM Corporation. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/** + * @test + * @bug 8376400 + * @summary C2: folding ifs may cause incorrect execution when trap is taken + * + * @run main/othervm -XX:-TieredCompilation -XX:-UseOnStackReplacement -XX:-BackgroundCompilation + * -XX:+UnlockDiagnosticVMOptions -XX:-OptimizeUnstableIf ${test.main.class} + * @run main ${test.main.class} + * + */ + +package compiler.rangechecks; + +public class TestFoldedIfsWrongReexec { + private static int taken1; + private static int taken2; + private static int taken3; + private static int taken4; + private static int taken5; + private static int taken6; + private static int taken7; + private static int MIN_VALUE = Integer.MIN_VALUE; + + public static void main(String[] args) { + for (int i = 0; i < 20_000; i++) { + test1(12); + if (taken1 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test1Helper1(16, 0); + test2(12); + if (taken2 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test2Helper1(16, 0); + test3(12); + if (taken3 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test3Helper1(16, 0); + test4(12, 1, 2); + if (taken4 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test4Helper1(16, 0, 1, 2); + test5(12); + if (taken5 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test5Helper1(16, 0); + test6(12, 1, 2); + if (taken6 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test6Helper1(16, 0, 1, 2); + test7(12); + if (taken7 != 0) { + throw new RuntimeException("branch shouldn't have been taken"); + } + test7Helper1(16, 0); + test7Helper2(o1); + test7Helper2(a); + test7Helper2(b); + } + test1(0); + if (taken1 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test2(0); + if (taken2 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test3(0); + if (taken3 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test4(0, 1, 2); + if (taken4 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test5(0); + if (taken5 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test6(0, 1, 2); + if (taken6 == 0) { + throw new RuntimeException("branch should have been taken"); + } + test7(0); + if (taken7 == 0) { + throw new RuntimeException("branch should have been taken"); + } + } + + private static void test1(int i) { + if (test1Helper1(i, 16) == 0) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken1++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static int test1Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return 0; + } + return 1; + } + + private static void test2(int i) { + if (test2Helper1(i, 16) == 42) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken2++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static int test2Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return 42; + } + return 0x42; + } + + private static void test3(int i) { + if (test3Helper1(i, 16) == 42L) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken3++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static long test3Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return 42L; + } + return 0x42L; + } + + private static void test4(int i, int x, int y) { + if (x == y) { + throw new RuntimeException("never taken"); + } + if (test4Helper1(i, 16, x, y) == y) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken4++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static int test4Helper1(int i, int j, int x, int y) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return y; + } + return x; + } + + static final Object o1 = new Object(); + static final Object o2 = new Object(); + + private static void test5(int i) { + if (test5Helper1(i, 16) == o1) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken5++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static Object test5Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return o1; + } + return o2; + } + + private static void test6(int i, int x, int y) { + if (x < y) { + if (test6Helper1(i, 16, x, y) < y) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken6++; + } + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static int test6Helper1(int i, int j, int x, int y) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return x; + } + return y; + } + + static final Object a = new A(); + static final Object b = new B(); + + private static void test7(int i) { + if (test7Helper2(test7Helper1(i, 16))) { + throw new RuntimeException("never taken"); + } + if (i + MIN_VALUE < 8 + Integer.MIN_VALUE) { + taken7++; + } + for (int j = 0; j < 10; j++) { + for (int k = 0; k < 10; k++) { + + } + } + } + + private static Object test7Helper1(int i, int j) { + if (i + MIN_VALUE >= j + Integer.MIN_VALUE) { + for (int k = 0; k < 100; k++) { + } + return a; + } + return b; + } + + private static boolean test7Helper2(Object o) { + return o instanceof A; + } + + private static class A { + } + + private static class B { + } +}