From 62843fcdcec870a920dadd30b5a9c4a695892ecf Mon Sep 17 00:00:00 2001 From: Emanuel Peter Date: Mon, 1 Jun 2026 06:56:33 +0000 Subject: [PATCH] 8346420: C2: IfNode::fold_compares_helper() wrongly folds two CmpI nodes to a single CmpU node Reviewed-by: roland, qamai, galder --- src/hotspot/share/opto/ifnode.cpp | 727 ++++++++++++----- src/hotspot/share/opto/subnode.hpp | 7 +- .../rangechecks/TestFoldCompares.java | 346 +++++++++ .../rangechecks/TestFoldComparesFuzzer.java | 728 ++++++++++++++++++ 4 files changed, 1606 insertions(+), 202 deletions(-) create mode 100644 test/hotspot/jtreg/compiler/rangechecks/TestFoldCompares.java create mode 100644 test/hotspot/jtreg/compiler/rangechecks/TestFoldComparesFuzzer.java diff --git a/src/hotspot/share/opto/ifnode.cpp b/src/hotspot/share/opto/ifnode.cpp index 9f99874d76a..347d63ef57c 100644 --- a/src/hotspot/share/opto/ifnode.cpp +++ b/src/hotspot/share/opto/ifnode.cpp @@ -654,6 +654,12 @@ Node* IfNode::up_one_dom(Node *curr, bool linear_only) { //------------------------------filtered_int_type-------------------------------- // Return a possibly more restrictive type for val based on condition control flow for an if +// +// Important: we only parse if val is on the lhs. This is a limitation, but it makes +// optimizations simpler. We rely on canonicalization to get us to this +// shape, which works well for comparisions with constants, as they are +// canonicalized to the rhs. This may not happen with variables, and so +// the optimization may not work for those cases, when val stays on the rhs. const TypeInt* IfNode::filtered_int_type(PhaseGVN* gvn, Node* val, Node* if_proj) { assert(if_proj && (if_proj->Opcode() == Op_IfTrue || if_proj->Opcode() == Op_IfFalse), "expecting an if projection"); @@ -663,11 +669,14 @@ const TypeInt* IfNode::filtered_int_type(PhaseGVN* gvn, Node* val, Node* if_proj BoolNode* bol = iff->in(1)->as_Bool(); if (bol->in(1) && bol->in(1)->is_Cmp()) { const CmpNode* cmp = bol->in(1)->as_Cmp(); + // Val is always the lhs of the comparision: val cmp2 if (cmp->in(1) == val) { + assert(cmp->Opcode() == Op_CmpI, "signed comparison required"); const TypeInt* cmp2_t = gvn->type(cmp->in(2))->isa_int(); if (cmp2_t != nullptr) { jint lo = cmp2_t->_lo; jint hi = cmp2_t->_hi; + // Negate the test if we are on the false branch. BoolTest::mask msk = if_proj->Opcode() == Op_IfTrue ? bol->_test._test : bol->_test.negate(); switch (msk) { case BoolTest::ne: { @@ -675,8 +684,12 @@ const TypeInt* IfNode::filtered_int_type(PhaseGVN* gvn, Node* val, Node* if_proj const TypeInt* val_t = gvn->type(val)->isa_int(); if (val_t != nullptr && !val_t->singleton() && cmp2_t->is_con()) { if (val_t->_lo == lo) { + // Condition leading to if_proj: val != val->lo + // val in [val->lo + 1, val->hi] return TypeInt::make(val_t->_lo + 1, val_t->_hi, val_t->_widen); } else if (val_t->_hi == hi) { + // Condition leading to if_proj: val != val->hi + // val in [val->lo, val->hi - 1] return TypeInt::make(val_t->_lo, val_t->_hi - 1, val_t->_widen); } } @@ -684,28 +697,38 @@ const TypeInt* IfNode::filtered_int_type(PhaseGVN* gvn, Node* val, Node* if_proj return nullptr; } case BoolTest::eq: + // Condition leading to if_proj: val == cmp2 + // val in cmp2_t return cmp2_t; case BoolTest::lt: - lo = TypeInt::INT->_lo; + // Condition leading to if_proj: val < cmp2 + // val in [min_int .. max(min_int, cmp2->_hi - 1)] + lo = min_jint; if (hi != min_jint) { hi = hi - 1; } break; case BoolTest::le: - lo = TypeInt::INT->_lo; + // Condition leading to if_proj: val <= cmp2 + // val in [min_int .. cmp2->_hi] + lo = min_jint; break; case BoolTest::gt: + // Condition leading to if_proj: val > cmp2 + // val in [min(cmp2->_lo + 1, max_int) .. max_int] if (lo != max_jint) { lo = lo + 1; } - hi = TypeInt::INT->_hi; + hi = max_jint; break; case BoolTest::ge: - // lo unchanged - hi = TypeInt::INT->_hi; + // Condition leading to if_proj: val >= cmp2 + // val in [cmp2->_lo .. max_int] + hi = max_jint; break; default: - break; + assert(false, "impossible case"); + return nullptr; } const TypeInt* rtn_t = TypeInt::make(lo, hi, cmp2_t->_widen); return rtn_t; @@ -902,219 +925,523 @@ bool IfNode::has_only_uncommon_traps(IfProjNode* proj, IfProjNode*& success, IfP return false; } -// Check that the 2 CmpI can be folded into as single CmpU and proceed with the folding -bool IfNode::fold_compares_helper(IfProjNode* proj, IfProjNode* success, IfProjNode* fail, PhaseIterGVN* igvn) { - Node* this_cmp = in(1)->in(1); - BoolNode* this_bool = in(1)->as_Bool(); - IfNode* dom_iff = proj->in(0)->as_If(); - BoolNode* dom_bool = dom_iff->in(1)->as_Bool(); - Node* lo = dom_iff->in(1)->in(1)->in(2); - Node* orig_lo = lo; - Node* hi = this_cmp->in(2); - Node* n = this_cmp->in(1); - IfProjNode* otherproj = proj->other_if_proj(); +// We are given the following code shape with two CmpI: +// +// n v1 +// | | +// cmp1 +// | +// entry bool1(test1) +// | | +// iff1 +// | \ +// middle fail1-------------+ +// | | +// | n v2 | +// | | | | +// maybe cmp2 | +// null-check | | +// | bool2(test2) | +// | | | +// iff2 | +// | \ v +// succ fail2----> go to same region +// or uncommon trap +// +// 1. In some cases, we can prove that succ cannot be reached, +// and we can fold away the iff2. Example: +// +// if (n < -1 && n > 1) { succ } else { fail } +// // 1st condition: n in [min_int .. -2] +// // 2nd condition: n in [2 .. max_int] +// // -> no overlap -> constant fold iff2 towards fail2 +// // +// // Equivalent, if we flip everything: +// if (n >= -1 || n <= 1) { fail } else { succ } +// +// 2. In other cases, we can replace the two CmpI with +// a single CmpU. We fold iff1 towards middle, and +// replace the iff2 condition with the CmpU. Example: +// +// if (n >= 0 && n < 10) { succ } else { fail } +// // transformed to: +// if (n = arr.length) { throw ArrayOutOfBoundsException } +// // transformed to: +// if (n >=u arr.length) { throw ArrayOutOfBoundsException } +// +// Note1: we assume that the CmpI nodes are canonicalized to the +// point where n is always on the lhs. This is a limitation, +// but as long as v1 and v2 are constants they will eventually +// be canonicalized to the rhs. For variables, this may not always +// happen. +// +// Note2: We are flexible about the IfProj nodes: middle and succ +// could both be either IfTrue or IfFalse. +// +// Note3: Surrounding code has a different naming scheme! +// In has_only_uncommon_traps, the path towards the +// uncommon trap (e.g. failed range check) is called +// "success", while the path that does not go to +// the uncommon trap (e.g. in-bounds access) is called +// "fail". I think that is counter-intuitive, so I now +// used a different naming scheme here. +// +// Return true iff we could perform one of the optimizations. +bool IfNode::fold_compares_helper(IfProjNode* middle, IfProjNode* fail2, IfProjNode* succ, PhaseIterGVN* igvn) { + assert(fail2->in(0) == this, "link iff2->fail2"); + assert(succ->in(0) == this, "link iff2->succ"); - const TypeInt* lo_type = IfNode::filtered_int_type(igvn, n, otherproj); - const TypeInt* hi_type = IfNode::filtered_int_type(igvn, n, success); + IfNode* iff1 = middle->in(0)->as_If(); + IfNode* iff2 = this; + BoolNode* bool1 = iff1->in(1)->as_Bool(); + BoolNode* bool2 = iff2->in(1)->as_Bool(); + CmpNode* cmp1 = bool1->in(1)->as_Cmp(); + CmpNode* cmp2 = bool2->in(1)->as_Cmp(); + assert(cmp1->Opcode() == Op_CmpI, "comparisons must be CmpI"); + assert(cmp2->Opcode() == Op_CmpI, "comparisons must be CmpI"); - BoolTest::mask lo_test = dom_bool->_test._test; - BoolTest::mask hi_test = this_bool->_test._test; - BoolTest::mask cond = hi_test; + IfProjNode* fail1 = middle->other_if_proj(); - PhaseTransform::SpeculativeProgressGuard progress_guard(igvn); - // convert: - // - // dom_bool = x {<,<=,>,>=} a - // / \ - // proj = {True,False} / \ otherproj = {False,True} - // / - // this_bool = x {<,<=} b - // / \ - // fail = {True,False} / \ success = {False,True} - // / - // - // (Second test guaranteed canonicalized, first one may not have - // been canonicalized yet) - // - // into: - // - // cond = (x - lo) {u,>=u} adjusted_lim - // / \ - // fail / \ success - // / - // + Node* v1 = cmp1->in(2); + Node* v2 = cmp2->in(2); + Node* n = cmp1->in(1); + assert(cmp2->in(1) == n, "n must be lhs in both CmpI"); - // Figure out which of the two tests sets the upper bound and which - // sets the lower bound if any. - Node* adjusted_lim = nullptr; - if (lo_type != nullptr && hi_type != nullptr && hi_type->_lo > lo_type->_hi && - hi_type->_hi == max_jint && lo_type->_lo == min_jint && lo_test != BoolTest::ne) { - assert((dom_bool->_test.is_less() && !proj->_con) || - (dom_bool->_test.is_greater() && proj->_con), "incorrect test"); - - // this_bool = < - // dom_bool = >= (proj = True) or dom_bool = < (proj = False) - // x in [a, b[ on the fail (= True) projection, b > a-1 (because of hi_type->_lo > lo_type->_hi test above): - // lo = a, hi = b, adjusted_lim = b-a, cond = (proj = True) or dom_bool = <= (proj = False) - // x in ]a, b[ on the fail (= True) projection, b > a: - // lo = a+1, hi = b, adjusted_lim = b-a-1, cond = = (proj = True) or dom_bool = < (proj = False) - // x in [a, b] on the fail (= True) projection, b+1 > a-1: - // lo = a, hi = b, adjusted_lim = b-a+1, cond = (proj = True) or dom_bool = <= (proj = False) - // x in ]a, b] on the fail (= True) projection b+1 > a: - // lo = a+1, hi = b, adjusted_lim = b-a, cond = transform(new AddINode(lo, igvn->intcon(1))); + // Optimization 1: try to prove that succ is not reachable. + // Which values of n can pass iff1 to middle AND iff2 to succ? + const TypeInt* type_middle = filtered_int_type(igvn, n, middle); + if (type_middle != nullptr) { + const TypeInt* type_succ = filtered_int_type(igvn, n, succ); + if (type_succ != nullptr) { + if (type_middle->filter(type_succ) == Type::TOP) { + // The intersection is empty -> succ is not reachable. + // Fold iff2 towards fail2 (and away from succ). + igvn->replace_input_of(iff2, 1, igvn->intcon(fail2->_con)); + return true; // success: succ not reachable } - } else if (hi_test == BoolTest::le) { - if (lo_test == BoolTest::ge || lo_test == BoolTest::lt) { - adjusted_lim = igvn->transform(new SubINode(hi, lo)); - adjusted_lim = igvn->transform(new AddINode(adjusted_lim, igvn->intcon(1))); - cond = BoolTest::lt; - } else if (lo_test == BoolTest::gt || lo_test == BoolTest::le) { - adjusted_lim = igvn->transform(new SubINode(hi, lo)); - lo = igvn->transform(new AddINode(lo, igvn->intcon(1))); - cond = BoolTest::lt; - } else { - assert(false, "unhandled lo_test: %d", lo_test); - return false; - } - } else { - assert(igvn->_worklist.member(in(1)) && in(1)->Value(igvn) != igvn->type(in(1)), "unhandled hi_test: %d", hi_test); - return false; } - // this test was canonicalized - assert(this_bool->_test.is_less() && fail->_con, "incorrect test"); - } else if (lo_type != nullptr && hi_type != nullptr && lo_type->_lo > hi_type->_hi && - lo_type->_hi == max_jint && hi_type->_lo == min_jint && lo_test != BoolTest::ne) { + } - // this_bool = < - // dom_bool = < (proj = True) or dom_bool = >= (proj = False) - // x in [b, a[ on the fail (= False) projection, a > b-1 (because of lo_type->_lo > hi_type->_hi above): - // lo = b, hi = a, adjusted_lim = a-b, cond = >=u - // dom_bool = <= (proj = True) or dom_bool = > (proj = False) - // x in [b, a] on the fail (= False) projection, a+1 > b-1: - // lo = b, hi = a, adjusted_lim = a-b+1, cond = >=u - // lo = b, hi = a, adjusted_lim = a-b, cond = >u doesn't work because a = b - 1 is possible, then b-a = -1 - // this_bool = <= - // dom_bool = < (proj = True) or dom_bool = >= (proj = False) - // x in ]b, a[ on the fail (= False) projection, a > b: - // lo = b+1, hi = a, adjusted_lim = a-b-1, cond = >=u - // dom_bool = <= (proj = True) or dom_bool = > (proj = False) - // x in ]b, a] on the fail (= False) projection, a+1 > b: - // lo = b+1, hi = a, adjusted_lim = a-b, cond = >=u - // lo = b+1, hi = a, adjusted_lim = a-b-1, cond = >u doesn't work because a = b is possible, then b-a-1 = -1 + // Optimization 2: try to replace the two CmpI with one CmpU + // We can handle the following 4 cases: + // Input: two CmpI Output: one CmpU Assumption + // -------------------- ------------------------- ------------------- + // a) (n > lo && n < hi) -> n - lo - 1 2 && n < 5 ) n - 3 lo && n <= hi) -> n - lo - 1 2 && n <= 5 ) n - 3 = lo && n < hi) -> n - lo = 2 && n < 5 ) n - 2 = lo && n <= hi) -> n - lo <=u hi - lo (assuming lo <= hi) + // (n >= 2 && n <= 5 ) n - 2 <=u 3 + // range: [2, 3, 4, 5] + // + // Note1: the rhs of the CmpU indicates the cardinality of the range, + // allowing n to have exactly that many different values. + // + // Note2: all 4 case have an assumption: lo must be sufficiently smaller + // than hi. Below, and with the use of Lemma1 from below, we will + // prove that this implies that the rhs of the CmpU never + // underflows or overflows, which is critical for correctness. + // + // Below, we will prove and implement each of these cases. But first, + // we must handle the combinations of IfTrue/IfFalse projections for + // middle and succ, and extract which one is the lower bound (lo) and + // which one the upper bound (hi). + // + // <---- lower bound -----> <----------- succ -------------> <---- upper bound -----> + // [min_int .. lo_type->hi] [lo_type->hi+1 .. hi_type->lo-1] [hi_type->lo .. max_int] + // ^ ^ + // n {>/>=} lo n { <------------ unsigned upper bound -------------> + // [0 .. ] [ .. max_uint] + // ^ + // CmpU - swap(lo, hi); - swap(lo_type, hi_type); - swap(lo_test, hi_test); + BoolTest::mask test1 = bool1->_test._test; + BoolTest::mask test2 = bool2->_test._test; + if (middle->Opcode() == Op_IfFalse) { test1 = BoolTest::negate_mask(test1); } + if (succ->Opcode() == Op_IfFalse) { test2 = BoolTest::negate_mask(test2); } - assert((dom_bool->_test.is_less() && proj->_con) || - (dom_bool->_test.is_greater() && !proj->_con), "incorrect test"); - - cond = (hi_test == BoolTest::le || hi_test == BoolTest::gt) ? BoolTest::gt : BoolTest::ge; - - if (lo_test == BoolTest::lt) { - if (hi_test == BoolTest::lt || hi_test == BoolTest::ge) { - cond = BoolTest::ge; - } else if (hi_test == BoolTest::le || hi_test == BoolTest::gt) { - adjusted_lim = igvn->transform(new SubINode(hi, lo)); - adjusted_lim = igvn->transform(new AddINode(adjusted_lim, igvn->intcon(1))); - cond = BoolTest::ge; - } else { - assert(false, "unhandled hi_test: %d", hi_test); - return false; - } - } else if (lo_test == BoolTest::le) { - if (hi_test == BoolTest::lt || hi_test == BoolTest::ge) { - lo = igvn->transform(new AddINode(lo, igvn->intcon(1))); - cond = BoolTest::ge; - } else if (hi_test == BoolTest::le || hi_test == BoolTest::gt) { - adjusted_lim = igvn->transform(new SubINode(hi, lo)); - lo = igvn->transform(new AddINode(lo, igvn->intcon(1))); - cond = BoolTest::ge; - } else { - assert(false, "unhandled hi_test: %d", hi_test); - return false; - } - } else { - assert(igvn->_worklist.member(in(1)) && in(1)->Value(igvn) != igvn->type(in(1)), "unhandled lo_test: %d", lo_test); - return false; - } - // this test was canonicalized - assert(this_bool->_test.is_less() && !fail->_con, "incorrect test"); + Node* lo = nullptr; + Node* hi = nullptr; + const TypeInt* lo_type = nullptr; + const TypeInt* hi_type = nullptr; + BoolTest::mask lo_test = BoolTest::illegal; + BoolTest::mask hi_test = BoolTest::illegal; + if (BoolTest::is_greater(test1) && BoolTest::is_less(test2)) { + lo = v1; + hi = v2; + lo_type = IfNode::filtered_int_type(igvn, n, fail1); + hi_type = IfNode::filtered_int_type(igvn, n, fail2); + lo_test = test1; + hi_test = test2; + } else if (BoolTest::is_less(test1) && BoolTest::is_greater(test2)) { + lo = v2; + hi = v1; + lo_type = IfNode::filtered_int_type(igvn, n, fail2); + hi_type = IfNode::filtered_int_type(igvn, n, fail1); + lo_test = test2; + hi_test = test1; } else { - const TypeInt* failtype = filtered_int_type(igvn, n, proj); - if (failtype != nullptr) { - const TypeInt* type2 = filtered_int_type(igvn, n, fail); - if (type2 != nullptr) { - if (failtype->filter(type2) == Type::TOP) { - // previous if determines the result of this if so - // replace Bool with constant - igvn->replace_input_of(this, 1, igvn->intcon(success->_con)); - progress_guard.commit(); - return true; - } - } - } + // Could not find upper and lower bound. + return false; + } + assert(BoolTest::is_greater(lo_test), "lower bound: n {>/>=} lo"); + assert(BoolTest::is_less(hi_test), "upper bound: n {_hi != max_jint || + lo_type->_lo != min_jint) { + // Upper and lower bounds could not be established. return false; } - assert(lo != nullptr && hi != nullptr, "sanity"); - Node* hook = new Node(lo); // Add a use to lo to prevent him from dying - // Merge the two compares into a single unsigned compare by building (CmpU (n - lo) (hi - lo)) - Node* adjusted_val = igvn->transform(new SubINode(n, lo)); - if (adjusted_lim == nullptr) { - adjusted_lim = igvn->transform(new SubINode(hi, lo)); - } - hook->destruct(igvn); + // ------------------------------------------------------------------- + // In the proofs below, we need some basic Lemmas to deal with integer + // signed and unsigned arithmetic. + // + // Lemma1: + // Let a and b be in [min_int .. max_int]. + // If a >=s b, then: + // U(a - b) = a - b + // + // Proof: + // a >= b + // -> a - b >= 0 + // + // a <= max_int + // b >= min_int + // -> a - b <= max_int - min_int = 2^32-1 + // + // 0 <= a - b <= 2^32-1 + // -> cast to unsigned has no overflow + // -> U(a - b) = a - b + // + // Lemma2: + // Let a and b be in [min_int .. max_int]. + // If a a - b < 0 + // + // a >= min_int + // b <= max_int + // -> a - b >= min_int - max_int = 2^32-1 + // + // 2^32-1 <= a - b < 0 + // -> cast to unsigned leads to exactly one overflow + // -> U(a - b) = a - b + 2^32 + // + // Lemma3: + // Let a and b be in [min_int .. max_int]. + // a + 2^32 > b + // + // Proof: + // Using a >= min_int, and b <= max_int: + // a + 2^32 >= min_int + 2^32 + // = max_int + 1 + // >= b + 1 + // > b + // ------------------------------------------------------------------- - if (adjusted_val->is_top() || adjusted_lim->is_top()) { - return false; + // Handle the 4 cases. + // All produce this form: n - lo + x1 hi - lo + x2 + Node* x1 = nullptr; + Node* x2 = nullptr; + BoolTest::mask cond = BoolTest::illegal; + if (lo_test == BoolTest::gt && hi_test == BoolTest::lt) { + // We perform the the (CHECK) below, which implies (LO-HI), + // as we will show below. + if (lo_type->_hi >= hi_type->_lo) { + return false; // (CHECK) fails, we cannot establish (LO-HI) assumption. + } + // a) (n > lo && n < hi) -> n - lo - 1 _hi] for n <= lo + // -> lo_type->_hi = lo->_hi + // hi_type = [hi->_lo .. max_int] for n >= lo + // -> hi_type->_lo = hi->_lo + // We will need the assumption (LO-HI) below, which we can + // establish with the following (CHECK): + // lo_type->_hi < hi_type->_lo (CHECK) + // -> lo->_hi < hi->_lo + // -> lo < hi (LO-HI) + // + // Case n <= lo: + // (BEFORE) is always false, show (AFTER) is always false. + // Since lo < hi (LO-HI), S(lo+1) = lo+1 (no overflow): + // -> lo+1 <= hi + // -> n < lo+1 + // U(n - (lo + 1)) < U(hi - (lo + 1)) + // -- Lemma2 (n < lo+1) -- -- Lemma1 (lo+1 <= hi) -- + // n - (lo + 1) + 2^32 < hi - (lo + 1) + // n + 2^32 < hi + // Always false by Lemma3. + // + // Case lo < n < hi: + // (BEFORE) is always true, show (AFTER) is always true. + // Since lo < hi (LO-HI), S(lo+1) = lo+1 (no overflow): + // -> lo+1 <= hi + // -> n >= lo+1 + // U(n - (lo + 1)) < U(hi - (lo + 1)) + // -- Lemma1 (n >= lo+1) -- -- Lemma1 (lo+1 <= hi) -- + // n - (lo + 1) < hi - (lo + 1) + // n < hi + // Corresponds to case assumption, so always true. + // + // Case n >= hi: + // (BEFORE) is always false, show (AFTER) is always false. + // Since lo < hi (LO-HI), S(lo+1) = lo+1 (no overflow): + // -> lo+1 <= hi + // U(n - (lo + 1)) < U(hi - (lo + 1)) + // -- Lemma1 (n >= lo+1) -- -- Lemma1 (lo+1 <= hi) -- + // n - (lo + 1) < hi - (lo + 1) + // n < hi + // Contradicts case assumption, so always false. + // QED. + // + // Note: we cannot use anything more relaxed than the assumption + // lo < hi: with lo=hi the rhs of the CmpU would underflow. + // + // Produce form: n - lo + x1 hi - lo + x2 + // n - lo - 1 intcon(-1); + x2 = igvn->intcon(-1); + cond = BoolTest::lt; + } else if (lo_test == BoolTest::gt && hi_test == BoolTest::le) { + // We perform the the (CHECK) below, which implies (LO-HI), + // as we will show below. + if (lo_type->_hi >= hi_type->_lo) { + return false; // (CHECK) fails, we cannot establish (LO-HI) assumption. + } + // b) (n > lo && n <= hi) -> n - lo - 1 _hi] for n <= lo + // -> lo_type->_hi = lo->_hi + // hi_type = [min(hi->_lo+1, max_int) .. max_int] for n > hi + // -> hi_type->_lo <= lo->_lo + 1 + // We will need the assumption (LO-HI) below, which we can + // establish with the following (CHECK): + // lo_type->_hi < hi_type->_lo (CHECK) + // -> lo->_hi < hi->_lo + 1 + // -> lo < hi + 1 + // -> lo <= hi (LO-HI) + // + // Case A: lo = hi + // Let y = lo = hi + // -> n > lo && n <= hi vs n - lo - 1 n > y && n <= y vs n - y - 1 n < lo+1 + // U(n - (lo + 1)) < U(hi - lo) + // -- Lemma2 (n < lo+1) -- -- Lemma1 (lo <= hi, LO-HI) -- + // n - (lo + 1) + 2^32 < hi - lo + // n - 1 + 2^32 < hi + // n + 2^32 <= hi + // Always false by Lemma3. + // Note: To apply Lemma2 above, we must use (Case B), we + // could not have done it with (LO-HI) alone. + // + // Case lo < n <= hi: + // (BEFORE) is always true, show (AFTER) is always true. + // Since lo < hi (Case B), S(lo+1) = lo+1 (no overflow): + // -> n >= lo+1 + // U(n - (lo + 1)) < U(hi - lo) + // -- Lemma1 (n >= lo+1) -- -- Lemma1 (lo <= hi, LO-HI) -- + // n - (lo + 1) < hi - lo + // n - 1 < hi + // n <= hi + // Follows from case assumption, so always true. + // + // Case n > hi: + // (BEFORE) is always false, show (AFTER) is always false. + // Since lo < hi (Case B), S(lo+1) = lo+1 (no overflow): + // -> lo+1 <= hi + // -> n > lo+1 + // U(n - (lo + 1)) < U(hi - lo) + // -- Lemma1 (n > lo+1) -- -- Lemma1 (lo <= hi, LO-HI) -- + // n - (lo + 1) < hi - lo + // n - 1 < hi + // n <= hi + // Contradicts case assumption, so always false. + // QED. + // + // Note: we cannot use anything more relaxed than the assumption + // lo <= hi: with lo=hi+1 the rhs of the CmpU would underflow. + // + // Produce form: n - lo + x1 hi - lo + x2 + // n - lo - 1 intcon(-1); + x2 = igvn->intcon(0); + cond = BoolTest::lt; + } else if (lo_test == BoolTest::ge && hi_test == BoolTest::lt) { + // We perform the the (CHECK) below, which implies (LO-HI), + // as we will show below. + if (lo_type->_hi >= hi_type->_lo) { + return false; // (CHECK) fails, we cannot establish (LO-HI) assumption. + } + // c) (n >= lo && n < hi) -> n - lo _hi - 1)] for n < lo + // -> lo_type->_hi >= lo->_hi - 1 + // hi_type = [b->_lo .. max_int] for n >= hi + // -> hi_type->_lo = hi->_lo + // We will need the assumption (LO-HI) below, which we can + // establish with the following (CHECK): + // lo_type->_hi < hi_type->_lo + // -> lo->_hi - 1 < hi->_lo + // -> lo->_hi <= hi->_lo + // -> lo <= hi (HI-LO) + // + // Case n < lo: + // (BEFORE) is always false, show (AFTER) is always false. + // U(n - lo) < U(hi - lo) + // -- Lemma2 (n < lo) -- -- Lemma1 (lo <= hi, LO-HI) -- + // n - lo + 2^32 < hi - lo + // n + 2^32 < hi + // Always false by Lemma3. + // + // Case lo <=s n = lo) -- -- Lemma1 (lo <= hi, LO-HI) -- + // n - lo < hi - lo + // n < hi + // Follows from case assumption, so always true. + // + // Case n >=s hi: + // (BEFORE) is always false, show (AFTER) is always false. + // U(n - lo) < U(hi - lo) + // -- Lemma1 (n >= lo) -- -- Lemma1 (lo <= hi, LO-HI) -- + // n - lo < hi - lo + // n < hi + // Contradicts case assumption, so always false. + // QED. + // + /// Note: we cannot use anything more relaxed than the assumption + // lo <= hi: with lo=hi+1 the rhs of the CmpU would underflow. + // + // Produce form: n - lo + x1 hi - lo + x2 + // n - lo intcon(0); + x2 = igvn->intcon(0); + cond = BoolTest::lt; + } else { + assert (lo_test == BoolTest::ge && hi_test == BoolTest::le, ""); + // We perform the the (CHECK) below, which implies (LO-HI), + // as we will show below. + jlong lo_type_hi = lo_type->_hi; + jlong hi_type_lo = hi_type->_lo; + if (lo_type_hi >= hi_type_lo - 1) { + return false; // (CHECK) fails, we cannot establish (LO-HI) assumption. + } + // d) (n >= lo && n <= hi) -> n - lo <=u hi - lo (assuming lo <= hi) + // (BEFORE) (AFTER) (LO-HI) + // + // Proof: + // From IfNode::filtered_int_type, we get: + // lo_type = [min_int .. max(min_int, lo->_hi-1)] for n < lo + // -> lo_type->_hi >= lo->_hi - 1 + // hi_type = [min(hi->_lo+1, max_int) .. max_int] for n > hi + // -> hi_type->_lo <= hi->_lo + 1 + // We will need the assumption (LO-HI) below, which we can + // establish with the following (CHECK), which we must compute in + // long to avoid underflow: + // lo_type->_hi < hi_type->_lo - 1 (CHECK) + // -> lo_type->_hi + 1 <= hi_type->_lo - 1 + // -> lo->_hi <= hi->_lo + // -> lo <= hi (LO-HI) + // + // Case n = lo, LO-HI) -- + // n - lo + 2^32 <= hi - lo + // n + 2^32 <= hi + // Always false by Lemma3. + // + // Case lo <=s n <=s hi: + // (BEFORE) is always true, show (AFTER) is always true. + // U(n - lo) <= U(hi - lo) + // -- Lemma1 (n >= lo) -- -- Lemma1 (hi >= lo, LO-HI) -- + // n - lo <= hi - lo + // n <= hi + // Corresponds to case assumption, so always true. + // + // Case n >s hi: + // (BEFORE) is always false, show (AFTER) is always false. + // U(n - lo) <= U(hi - lo) + // -- Lemma1 (n > lo) -- -- Lemma1 (hi >= lo, LO-HI) -- + // n - lo <= hi - lo + // n <= hi + // n <= hi + // Contradicts case assumption, so always false. + // QED. + // + // Note: (CHECK) is stronger in this case than in (a, b, c). We have + // had multiple bugs around this case (d) in the past. For example: + // - Before JDK-8135069: transform into: n - lo <=u hi - lo + // leads to rhs underflow with lo=0 and hi=-1 + // -> we are coming back to this solution, but instead + // of checking lo_type->_hi < hi_type->_lo + // we now check: lo_type->_hi < hi_type->_lo - 1 + // which implies lo <= hi and excludes this bad case. + // - Before JDK-8346420: transform into: n - lo hi - lo + x2 + // n - lo <=u hi - lo + x1 = igvn->intcon(0); + x2 = igvn->intcon(0); + cond = BoolTest::le; } - if (igvn->type(adjusted_lim)->is_int()->_lo < 0 && - !igvn->C->post_loop_opts_phase()) { - // If range check elimination applies to this comparison, it includes code to protect from overflows that may - // cause the main loop to be skipped entirely. Delay this transformation. - // Example: - // for (int i = 0; i < limit; i++) { - // if (i < max_jint && i > min_jint) {... - // } - // Comparisons folded as: - // i - min_jint - 1 outcnt() == 0) { - igvn->remove_dead_node(lo, PhaseIterGVN::NodeOrigin::Speculative); - } - if (adjusted_val->outcnt() == 0) { - igvn->remove_dead_node(adjusted_val, PhaseIterGVN::NodeOrigin::Speculative); - } - if (adjusted_lim->outcnt() == 0) { - igvn->remove_dead_node(adjusted_lim, PhaseIterGVN::NodeOrigin::Speculative); - } - igvn->C->record_for_post_loop_opts_igvn(this); - return false; - } - - Node* newcmp = igvn->transform(new CmpUNode(adjusted_val, adjusted_lim)); + // Construct the new check: n - lo + x1 hi - lo + x2 + Node* lhs = igvn->transform(new SubINode(n, lo)); + lhs = igvn->transform(new AddINode(lhs, x1)); + Node* rhs = igvn->transform(new SubINode(hi, lo)); + rhs = igvn->transform(new AddINode(rhs, x2)); + Node* newcmp = igvn->transform(new CmpUNode(lhs, rhs)); + if (succ->Opcode() == Op_IfFalse) { cond = BoolTest::negate_mask(cond); } Node* newbool = igvn->transform(new BoolNode(newcmp, cond)); - igvn->replace_input_of(dom_iff, 1, igvn->intcon(proj->_con)); - igvn->replace_input_of(this, 1, newbool); + // Fold iff1 towards middle, and replace the iff2 condition: + igvn->replace_input_of(iff1, 1, igvn->intcon(middle->_con)); + igvn->replace_input_of(iff2, 1, newbool); - progress_guard.commit(); - return true; + return true; // Success with CmpU } // Merge the branches that trap for this If and the dominating If into diff --git a/src/hotspot/share/opto/subnode.hpp b/src/hotspot/share/opto/subnode.hpp index 29ec25b41f8..358508248d0 100644 --- a/src/hotspot/share/opto/subnode.hpp +++ b/src/hotspot/share/opto/subnode.hpp @@ -334,8 +334,11 @@ struct BoolTest { static mask negate_mask(mask btm) { return mask(btm ^ 4); } static mask unsigned_mask(mask btm); bool is_canonical( ) const { return (_test == BoolTest::ne || _test == BoolTest::lt || _test == BoolTest::le || _test == BoolTest::overflow); } - bool is_less( ) const { return _test == BoolTest::lt || _test == BoolTest::le; } - bool is_greater( ) const { return _test == BoolTest::gt || _test == BoolTest::ge; } + bool is_less( ) const { return is_less(_test); } + bool is_greater( ) const { return is_greater(_test); } + static bool is_less(mask btm) { return btm == BoolTest::lt || btm == BoolTest::le; } + static bool is_greater(mask btm) { return btm == BoolTest::gt || btm == BoolTest::ge; } + void dump_on(outputStream *st) const; mask merge(BoolTest other) const; }; diff --git a/test/hotspot/jtreg/compiler/rangechecks/TestFoldCompares.java b/test/hotspot/jtreg/compiler/rangechecks/TestFoldCompares.java new file mode 100644 index 00000000000..bec3e442403 --- /dev/null +++ b/test/hotspot/jtreg/compiler/rangechecks/TestFoldCompares.java @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test id=vanilla + * @bug 8346420 + * @summary Test logic in IfNode::fold_compares, which folds 2 signed comparisons + * into a single comparison. + * @library /test/lib / + * @run main ${test.main.class} + */ + +/* + * @test id=Xcomp + * @bug 8346420 + * @library /test/lib / + * @run main ${test.main.class} -Xcomp -XX:-TieredCompilation -XX:CompileCommand=compileonly,${test.main.class}::test* + */ + +package compiler.rangechecks; + +import compiler.lib.ir_framework.*; + +/** + * This test here is here to cover some basic cases of IfNode::fold_compares. It also contains the + * reproducers for JDK-8346420. We don't do any result verification, other than that we should never + * hit an Exception. For a test with result verification, see TestFoldComparesFuzzer.java + */ +public class TestFoldCompares { + public static boolean FLAG_FALSE = false; + + public static void main(String[] args) { + TestFramework framework = new TestFramework(); + framework.addFlags(args); + framework.start(); + } + + // ------------------------- Failing cases for JDK-8346420 ------------------------------ + + @Test + @Arguments(values = {Argument.NUMBER_42}) + // Reported overflow case with wrong result in JDK-8346420 + public static void test_Case3a_LTLE_overflow(int i) { + int minimum, maximum; + if (FLAG_FALSE) { + minimum = 0; + maximum = 1; + } else { + // Always goes to else-path + minimum = Integer.MIN_VALUE; + maximum = Integer.MAX_VALUE; + } + // i < INT_MIN || i > MAX_INT + // 42 < INT_MIN || 42 > MAX_INT + // false false + // => false + // + // C2 transforms this into: + // i - minimum >=u (maximum - minimum) + 1 + // 42 - INT_MIN >=u (INT_MAX - INT_MIN) + 1 + // 42 + MIN_INT >=u -1 + 1 + // ------ overflow ------- + // 42 + MIN_INT >=u 0 + // => true + if (i < minimum || i > maximum) { + throw new RuntimeException("i can never be outside [min_int, max_int]"); + } + } + + @Test + @Arguments(values = {Argument.NUMBER_42}) + // Same as test_Case3a_LTLE_overflow, just with swapped conditions (JDK-8346420). + public static void test_Case3b_LTLE_overflow(int i) { + int minimum, maximum; + if (FLAG_FALSE) { + minimum = 0; + maximum = 1; + } else { + // Always goes to else-path + minimum = Integer.MIN_VALUE; + maximum = Integer.MAX_VALUE; + } + if (i > maximum || i < minimum) { + throw new RuntimeException("i can never be outside [min_int, max_int]"); + } + } + + @Test + @Arguments(values = {Argument.NUMBER_42}) + // 22 ConI === 0 [[ 25 37 ]] #int:0 + // 35 ConI === 0 [[ 37 ]] #int:minint + // 33 ConI === 0 [[ 38 81 ]] #int:1 + // 37 Phi === 34 35 22 [[ 42 80 81 84 ]] #int:minint..0, 0u..maxint+1 + // 81 AddI === _ 37 33 [[ 82 ]] + // 82 Node === 81 [[ ]] <----- hook + // + // We hit this assert, found while working on JDK-8346420: + // "fatal error: no reachable node should have no use" + // + // Because we compute: + // lo = lo + 1 + // hook = Node(lo) + // adjusted_val = i - lo + // -> gvn transformed to: (i - lo) + -1 + // -> the "lo = lo + 1" AddI now is only used by the hook, + // but once the hook is destroyed, it has no use any more, + // and we hit the assert. + public static void test_Case4a_LELE_assert(int i) { + int minimum, maximum; + if (FLAG_FALSE) { + minimum = 0; + maximum = 1; + } else { + minimum = Integer.MIN_VALUE; + maximum = Integer.MAX_VALUE; + } + if (i <= minimum || i > maximum) { + throw new RuntimeException("should never be reached"); + } + } + + // ------------------- IR tests to check that optimization was performed ------------------------ + + // The following tests with constant bounds are expected to fold to a single CmpU. + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}) + @Arguments(values = {Argument.NUMBER_42}) + public static void test_lohi_ltle(int i) { + if (i < -100_000 || i > 100_000) { + throw new RuntimeException(); + } + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}) + @Arguments(values = {Argument.NUMBER_42}) + public static void test_lohi_lele(int i) { + if (i <= -100_000 || i > 100_000) { + throw new RuntimeException(); + } + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}) + @Arguments(values = {Argument.NUMBER_42}) + public static void test_lohi_ltlt(int i) { + if (i < -100_000 || i >= 100_000) { + throw new RuntimeException(); + } + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}) + @Arguments(values = {Argument.NUMBER_42}) + public static void test_lohi_lelt(int i) { + if (i <= -100_000 || i >= 100_000) { + throw new RuntimeException(); + } + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}) + @Arguments(values = {Argument.NUMBER_42}) + public static void test_hilo_ltle(int i) { + if (i >= 100_000 || i <= -100_000) { + throw new RuntimeException(); + } + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}) + @Arguments(values = {Argument.NUMBER_42}) + public static void test_hilo_lele(int i) { + if (i > 100_000 || i <= -100_000) { + throw new RuntimeException(); + } + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}) + @Arguments(values = {Argument.NUMBER_42}) + public static void test_hilo_lelt(int i) { + if (i > 100_000 || i < -100_000) { + throw new RuntimeException(); + } + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}) + @Arguments(values = {Argument.NUMBER_42}) + public static void test_hilo_ltlt(int i) { + if (i >= 100_000 || i < -100_000) { + throw new RuntimeException(); + } + } + + // The following tests can completely remove the test and branches, we can prove that + // the path cannot be taken. + + @Setup + public static Object[] range256(SetupInfo info) { + return new Object[]{info.invocationCounter() & 255}; + } + + @Setup + public static Object[] rangeM128P127(SetupInfo info) { + return new Object[]{(info.invocationCounter() & 255) - 128}; + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 0"}) + @Arguments(setup = "rangeM128P127") + // Case from JDK-8135069. We used to do the CmpI->CmpU trick, but we can also constant fold + // this directly! + public static void test_empty_0(int i) { + if (i < 0 || i > -1) { + return; // always success + } + throw new RuntimeException("should not be reached"); + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 0"}) + @Arguments(setup = "range256") + public static void test_empty_1(int i) { + if (i < 100 || i > 50) { + return; // always success + } + throw new RuntimeException("should not be reached"); + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 0"}) + @Arguments(setup = "range256") + public static void test_empty_2(int i) { + if (i <= 100 || i >= 101) { + return; // always success + } + throw new RuntimeException("should not be reached"); + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 1", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + // Note: the two CmpI->Bool pairs are already canonicallized and commoned to a single pair. + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 0"}) + @Arguments(setup = "range256") + public static void test_empty_3(int i) { + if (i <= 100 || i > 100) { + return; // always success + } + throw new RuntimeException("should not be reached"); + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 1", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + // Note: the two CmpI->Bool pairs are already canonicallized and commoned to a single pair. + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 0"}) + @Arguments(setup = "range256") + public static void test_empty_4(int i) { + if (i < 101 || i >= 101) { + return; // always success + } + throw new RuntimeException("should not be reached"); + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 0"}) + @Arguments(setup = "range256") + public static void test_empty_5(int i) { + if (i < 101 || i > 100) { + return; // always success + } + throw new RuntimeException("should not be reached"); + } + + // Now test that we can use a.length, which means we do a null-check + // and then a comparison with a LoadRange that has type int[>=0] + + public static int[] ARR = new int[256]; + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING, + applyIf = {"TieredCompilation", "true"}) // proxy for "not Xcomp" + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}, + applyIf = {"TieredCompilation", "true"}) // proxy for "not Xcomp" + @Arguments(setup = "range256") + // Note: cannot get optimized with Xcomp + static int test_array_length_and_null_check_1(int i) { + if (i < 0 || i >= ARR.length) { + return -1; // never happens + } + return i; + } + + @Check(test = "test_array_length_and_null_check_1") + public void check_test_array_length_and_null_check_1(int i) { + if (i < 0) { throw new RuntimeException("Wrong value: " + i); } + } + + @Test + @IR(counts = {IRNode.CMP_I, "= 2", IRNode.CMP_U, "= 0"}, phase = CompilePhase.AFTER_PARSING, + applyIf = {"TieredCompilation", "true"}) // proxy for "not Xcomp" + @IR(counts = {IRNode.CMP_I, "= 0", IRNode.CMP_U, "= 1"}, + applyIf = {"TieredCompilation", "true"}) // proxy for "not Xcomp" + @Arguments(setup = "range256") + // Note: cannot get optimized with Xcomp + static int test_array_length_and_null_check_2(int i) { + if (i < 0 || i >= ARR.length) { + throw new RuntimeException("never go out of bounds"); + } + return i; + } +} diff --git a/test/hotspot/jtreg/compiler/rangechecks/TestFoldComparesFuzzer.java b/test/hotspot/jtreg/compiler/rangechecks/TestFoldComparesFuzzer.java new file mode 100644 index 00000000000..0467689eb17 --- /dev/null +++ b/test/hotspot/jtreg/compiler/rangechecks/TestFoldComparesFuzzer.java @@ -0,0 +1,728 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + + +/* + * @test + * @bug 8346420 + * @summary Fuzz patterns for IfNode::fold_compares_helper + * @modules java.base/jdk.internal.misc + * @library /test/lib / + * @compile ../lib/ir_framework/TestFramework.java + * @compile ../lib/generators/Generators.java + * @compile ../lib/verify/Verify.java + * @run driver ${test.main.class} + */ + +package compiler.rangechecks; + +import java.util.List; +import java.util.ArrayList; +import java.util.Random; +import java.util.HashSet; +import java.util.Set; + +import jdk.test.lib.Utils; + +import compiler.lib.compile_framework.*; +import compiler.lib.generators.*; +import compiler.lib.template_framework.Template; +import compiler.lib.template_framework.TemplateToken; +import static compiler.lib.template_framework.Template.scope; +import static compiler.lib.template_framework.Template.let; +import static compiler.lib.template_framework.Template.$; + +import compiler.lib.template_framework.library.TestFrameworkClass; + +/** + * For more basic examples, see TestFoldCompares.java + * + * I'm only covering some basic cases to test the fundamental + * logic inside IfNode::fold_compares_helper. + * - TestMethodGeneratorConstIR does extensive result and IR verification + * for the cases a-d) in IfNode::fold_compares_helper, but only with + * constant lo and hi. + * - Other test generators currently don't have IR rules, but check + * correctness in various relevant scenarios I came across during + * the bugfix of JDK-8346420. + * - I'm also mixing signed and unsigned comparisons, just to ensure + * the less often used (and tested) unsigned comparisons don't slip + * through the cracks. + * + * In the future, we could add more cases: + * - Extend to long - though the optimization does not yet cover longs anyway. + * - More IR rules: difficult to make stable. Not all permutations are covered + * by the optimizations, edge-cases could make IR rules brittle. + */ +public class TestFoldComparesFuzzer { + private static final Random RANDOM = Utils.getRandomInstance(); + private static final RestrictableGenerator INT_GEN = Generators.G.ints(); + + public static void main(String[] args) { + // Create a new CompileFramework instance. + CompileFramework comp = new CompileFramework(); + + long t0 = System.nanoTime(); + // Add a java source file. + comp.addJavaSourceCode("compiler.rangecheck.templated.Generated", generate(comp)); + + long t1 = System.nanoTime(); + // Compile the source file. + comp.compile(); + + long t2 = System.nanoTime(); + + // Run the tests without any additional VM flags. + comp.invoke("compiler.rangecheck.templated.Generated", "main", new Object[] {new String[] {}}); + long t3 = System.nanoTime(); + + System.out.println("Code Generation: " + (t1-t0) * 1e-9f); + System.out.println("Code Compilation: " + (t2-t1) * 1e-9f); + System.out.println("Running Tests: " + (t3-t2) * 1e-9f); + } + + public static String generate(CompileFramework comp) { + // Create a list to collect all tests. + List testTemplateTokens = new ArrayList<>(); + + for (int i = 0; i < 100; i++) { + testTemplateTokens.add(generateTest(/* no warmup, like -Xcomp */ 0)); + } + for (int i = 0; i < 5; i++) { + testTemplateTokens.add(generateTest(/* with warmup, slower */ 10_000)); + } + + // Create the test class, which runs all testTemplateTokens. + return TestFrameworkClass.render( + // package and class name. + "compiler.rangecheck.templated", "Generated", + // List of imports. + Set.of("compiler.lib.generators.*", + "compiler.lib.verify.*", + "java.util.Random", + "jdk.test.lib.Utils"), + // classpath, so the Test VM has access to the compiled class files. + comp.getEscapedClassPathOfCompiledClasses(), + // The list of tests. + testTemplateTokens); + } + + enum Comparator { + // TODO: enable again after JDK-8385157 + // ULT(" < 0", false), + // ULE(" <= 0", false), + // UGT(" > 0", false), + // UGE(" >= 0", false), + // UEQ(" == 0", false), + // UNE(" != 0", false), + LT(" < ", true), + LE(" <= ", true), + GT(" > ", true), + GE(" >= ", true), + EQ(" == ", true), + NE(" != ", true); + + private final String token; + private final boolean signed; + + Comparator(String token, boolean signed) { + this.token = token; + this.signed = signed; + } + + public String getToken() { + return token; + } + + public boolean isSigned() { + return signed; + } + + public Comparator negate() { + return switch(this) { + // TODO: enable again after JDK-8385157 + // case ULT -> UGE; + // case ULE -> UGT; + // case UGT -> ULE; + // case UGE -> ULT; + // case UEQ -> UNE; + // case UNE -> UEQ; + case LT -> GE; + case LE -> GT; + case GT -> LE; + case GE -> LT; + case EQ -> NE; + case NE -> EQ; + }; + } + + public Comparator flip() { + return switch(this) { + // TODO: enable again after JDK-8385157 + // case ULT -> UGT; + // case ULE -> UGE; + // case UGT -> ULT; + // case UGE -> ULE; + // case UEQ -> UEQ; + // case UNE -> UNE; + case LT -> GT; + case LE -> GE; + case GT -> LT; + case GE -> LE; + case EQ -> EQ; + case NE -> NE; + }; + } + + static Comparator random() { + return values()[RANDOM.nextInt(values().length)]; + } + + static Comparator randomGreater() { + return RANDOM.nextBoolean() ? GE : GT; + } + + static Comparator randomLess() { + return RANDOM.nextBoolean() ? LE : LT; + } + } + + record Comparison(String lhs, Comparator cmp, String rhs, boolean negated) { + public Comparison(String lhs, Comparator cmp, String rhs) { + this(lhs, cmp, rhs, false); + } + + public String toString() { + return cmp.isSigned() + ? ((negated ? "!" : "") + "(" + lhs + " "+ cmp.getToken() + " " + rhs + ")") + : ((negated ? "!" : "") + "(Integer.compareUnsigned(" + lhs + ", " + rhs + ")" + cmp.getToken() + ")"); + } + + // Keep the same semantics of the test, but change its form. + Comparison permuteRandom() { + return flipRandom().complementRandom(); + } + + Comparison flipRandom() { + return RANDOM.nextBoolean() ? this : new Comparison(rhs, cmp.flip(), lhs); + } + + Comparison complementRandom() { + return RANDOM.nextBoolean() ? this : new Comparison(lhs, cmp.negate(), rhs, true); + } + + Comparison negateCmp() { + return new Comparison(lhs, cmp.negate(), rhs, negated); + } + } + + interface TestMethodGenerator { + Template.OneArg getTestTemplate(); + + default Template.ZeroArgs getIRTemplate(boolean withWarmup) { + return Template.make(() -> scope("// No IR rule.\n")); + } + + default Template.ZeroArgs getInputTemplate() { + return Template.make(() -> scope( + """ + RestrictableGenerator gen = Generators.G.ints(); + int n = gen.next(); + int a = gen.next(); + int b = gen.next(); + """ + )); + }; + } + + // Some basic ranges with constant bounds. + // This should test some basic correctness, and also covers the case + // of bug JDK-8135069. + static class TestMethodGeneratorConst implements TestMethodGenerator { + private final int con1 = INT_GEN.next(); + private final int con2 = INT_GEN.next(); + + private final Comparison c1 = new Comparison("n", Comparator.random(), "con1").permuteRandom(); + private final Comparison c2 = new Comparison("n", Comparator.random(), "con2").permuteRandom(); + + private final Template.OneArg testTemplate = Template.make("methodName", (String methodName) -> scope( + let("con1", con1), + let("con2", con2), + let("c1", c1), + let("c2", c2), + """ + static boolean #methodName(int n, int a, int b) { + int con1 = #con1; + int con2 = #con2; + if (#c1 || #c2) { + return true; + } + return false; + } + """ + )); + + public Template.OneArg getTestTemplate() { return testTemplate; } + } + + // Cases where a and b are ranges that touch min_int/max_int. + // Note: if con1=0 and con2=1 then this is like the cases: + // - test_Case3a_LTLE_overflow + // - test_Case3b_LTLE_overflow + // - test_Case4a_LELE_assert + // + // Hence, I think this test gives us quite good coverage for the kinds of bugs + // such as JDK-8346420. + static class TestMethodGeneratorWithIf implements TestMethodGenerator { + private final int con1 = INT_GEN.next(); + private final int con2 = INT_GEN.next(); + private final String m1 = RANDOM.nextBoolean() ? "Integer.MIN_VALUE" : "Integer.MAX_VALUE"; + private final String m2 = RANDOM.nextBoolean() ? "Integer.MIN_VALUE" : "Integer.MAX_VALUE"; + + private final Comparison c1 = new Comparison("n", Comparator.random(), "a").permuteRandom(); + private final Comparison c2 = new Comparison("n", Comparator.random(), "b").permuteRandom(); + + private final Template.OneArg testTemplate = Template.make("methodName", (String methodName) -> scope( + let("con1", con1), + let("con2", con2), + let("m1", m1), + let("m2", m2), + let("c1", c1), + let("c2", c2), + """ + static boolean #methodName(int n, int a, int b) { + if (a < b) { + a = #con1; + b = #con2; + } else { + a = #m1; + b = #m2; + } + if (#c1 || #c2) { + return true; + } + return false; + } + """ + )); + + public Template.OneArg getTestTemplate() { return testTemplate; } + } + + // Just for good practice: add some case where the ranges are more free. + static class TestMethodGeneratorRanges implements TestMethodGenerator { + private final int n_hi = INT_GEN.next(); + private final int n_lo = INT_GEN.next(); + private final int a_hi = INT_GEN.next(); + private final int a_lo = INT_GEN.next(); + private final int b_hi = INT_GEN.next(); + private final int b_lo = INT_GEN.next(); + + private final Comparison c1 = new Comparison("n", Comparator.random(), "a").permuteRandom(); + private final Comparison c2 = new Comparison("n", Comparator.random(), "b").permuteRandom(); + + private final Template.OneArg template = Template.make("methodName", (String methodName) -> scope( + let("n_hi", n_hi), + let("n_lo", n_lo), + let("a_hi", a_hi), + let("a_lo", a_lo), + let("b_hi", b_hi), + let("b_lo", b_lo), + let("c1", c1), + let("c2", c2), + """ + static boolean #methodName(int n, int a, int b) { + n = Math.min(#n_hi, Math.max(#n_lo, n)); + a = Math.min(#a_hi, Math.max(#a_lo, a)); + b = Math.min(#b_hi, Math.max(#b_lo, b)); + if (#c1 || #c2) { + return true; + } + return false; + } + """ + )); + + public Template.OneArg getTestTemplate() { + return template; + } + } + + // Generate some more constrained cases, but with IR rules + static class TestMethodGeneratorConstIR implements TestMethodGenerator { + private final int lo; + private final int hi; + { // instance initializer + // We want to cover all cases for lo and hi combinations. But the + // critical cases happen around int_min and int_max, and when + // lo and hi are close to each other. + switch (RANDOM.nextInt(3)) { + case 0 -> { + // Full freedom, will eventually cover all cases + lo = INT_GEN.next(); + hi = INT_GEN.next(); + } + case 1 -> { + // Pick cases around overflow and underflow + lo = Integer.MAX_VALUE - 5 + RANDOM.nextInt(10); + hi = Integer.MAX_VALUE - 5 + RANDOM.nextInt(10); + } + default -> { + // Pick cases where lo and hi are close to each other + lo = INT_GEN.next(); + hi = lo - 5 + RANDOM.nextInt(10); + } + } + } + + // Since we are using constants for lo and hi, the checks should get canonicalized, + // so that n is always in the lhs. We only create cases that are covered by the + // 4 cases of "2 CmpI -> 1 CmpU" optimization in IfNode::fold_compares_helper. + private final Comparison c_lo = new Comparison("n", Comparator.randomGreater(), "lo"); + private final Comparison c_hi = new Comparison("n", Comparator.randomLess(), "hi"); + private final boolean swap = RANDOM.nextBoolean(); + private final Comparison c1Permuted = (swap ? c_lo : c_hi).permuteRandom(); + private final Comparison c2Permuted = (swap ? c_hi : c_lo).permuteRandom(); + // n > lo && n < hi -> check for inside range + // n <= lo || n >= hi -> chedk for outside range + private final boolean withAnd = RANDOM.nextBoolean(); + private final String operator = withAnd ? "&&" : "||"; + private final Comparison c1 = withAnd ? c1Permuted : c1Permuted.negateCmp(); + private final Comparison c2 = withAnd ? c2Permuted : c2Permuted.negateCmp(); + + private final Template.OneArg testTemplate = Template.make("methodName", (String methodName) -> scope( + let("lo", lo), + let("hi", hi), + let("c1", c1), + let("c2", c2), + let("op", operator), + """ + static boolean #methodName(int n, int a, int b) { + int lo = #lo; + int hi = #hi; + if (#c1 #op #c2) { + return true; + } + return false; + } + """ + )); + + public Template.OneArg getTestTemplate() { return testTemplate; } + + public Template.ZeroArgs getIRTemplate(boolean withWarmup) { + return Template.make(() -> { + String cmpIParse, cmpUParse, cmpIFinal, cmpUFinal; + String comment; + + // If both branches are compiled (in -Xcomp mode, i.e. no warmup), then + // we can know very precisely what happens in each case. + if (c_lo.cmp() == Comparator.GT && c_hi.cmp() == Comparator.LT) { + // a) (n > lo && n < hi) + if (lo == Integer.MAX_VALUE || hi == Integer.MIN_VALUE) { + cmpIParse = "< 2"; cmpUParse = "= 0"; cmpIFinal = "< 2"; cmpUFinal = "= 0"; + comment = "a) one or both checks fold at parse time"; + } else if (lo < hi && lo+2 == hi) { + // Not yet folded at parsing, because lo != hi + // BoolNode::Ideal: x x==0 (signed) + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 1"; cmpUFinal = "= 0"; + comment = "a) replace with CmpU (single element) -> CmpI eq"; + } else if (lo < hi && lo+1 == hi) { + // Not yet folded at parsing, because lo != hi + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 0"; + comment = "a) impossible condition (exact) -> fold away"; + } else if (lo < hi) { + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 1"; + comment = "a) replace with CmpU (non-empty)"; + } else if (lo == hi) { + // same CmpI at parse time + cmpIParse = "= 1"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 0"; + comment = "a) impossible condition -> fold away"; + } else { + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 0"; + comment = "a) impossible condition -> fold away"; + } + } else if (c_lo.cmp() == Comparator.GT && c_hi.cmp() == Comparator.LE) { + // b) (n > lo && n <= hi) + if (lo == Integer.MAX_VALUE || hi == Integer.MAX_VALUE) { + cmpIParse = "< 2"; cmpUParse = "= 0"; cmpIFinal = "< 2"; cmpUFinal = "= 0"; + comment = "b) one or both checks fold at parse time"; + } else if (lo < hi && lo+1 == hi) { + // BoolNode::Ideal: x x==0 (signed) + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 1"; cmpUFinal = "= 0"; + comment = "b) replace with CmpU (single element) -> CmpI eq"; + } else if (lo < hi && lo+1 < hi) { + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 1"; + comment = "b) replace with CmpU (non-empty)"; + } else if (lo == hi) { + cmpIParse = "= 1"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 0"; + comment = "b) impossible condition (exact) -> fold away"; + } else { + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 0"; + comment = "b) impossible condition -> fold away"; + } + } else if (c_lo.cmp() == Comparator.GE && c_hi.cmp() == Comparator.LT) { + // c) (n >= lo && n < hi) + if (lo == Integer.MIN_VALUE || hi == Integer.MIN_VALUE) { + cmpIParse = "< 2"; cmpUParse = "= 0"; cmpIFinal = "< 2"; cmpUFinal = "= 0"; + comment = "c) one or both checks fold at parse time"; + } else if (lo < hi && lo+1 == hi) { + // BoolNode::Ideal: x x==0 (signed) + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 1"; cmpUFinal = "= 0"; + comment = "c) replace with CmpU (single element) -> CmpI eq"; + } else if (lo < hi && lo+1 < hi) { + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 1"; + comment = "c) replace with CmpU (non-empty)"; + } else if (lo == hi) { + // RegionNode::optimize_trichotomy: can fold (n >= x && n < x) -> never + cmpIParse = "< 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 0"; + comment = "c) impossible condition (exact) -> fold away"; + } else { + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 0"; + comment = "c) impossible condition -> fold away"; + } + } else if (c_lo.cmp() == Comparator.GE && c_hi.cmp() == Comparator.LE) { + // d) (n >= lo && n <= hi) + if (lo == Integer.MIN_VALUE || hi == Integer.MAX_VALUE) { + cmpIParse = "< 2"; cmpUParse = "= 0"; cmpIFinal = "< 2"; cmpUFinal = "= 0"; + comment = "d) one or both checks fold at parse time"; + } else if (lo == hi) { + // same CmpI at parse time + // BoolNode::Ideal: x x==0 (signed) + cmpIParse = "= 1"; cmpUParse = "= 0"; cmpIFinal = "= 1"; cmpUFinal = "= 0"; + comment = "d) replace with CmpU (single element) -> CmpI eq"; + } else if (lo < hi) { + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 1"; + comment = "d) replace with CmpU (non-empty)"; + } else { + cmpIParse = "= 2"; cmpUParse = "= 0"; cmpIFinal = "= 0"; cmpUFinal = "= 0"; + comment = "d) impossible condition -> fold away"; + } + } else { + throw new RuntimeException("should not be generated: " + c_lo + " and " + c_hi); + } + + // All the precise counting above assumes that both ifs get compiled, and hence + // both CmpI are generated. Further, it assumes that both of the "or" branches + // (fail1 and fail2) end up "in the same place": either at the same region, or + // both in an uncommon trap. With profiling, the following cases are possible: + // - The first if is constant folded to fail1, and we have no CmpI nor CmpU + // in the graph. + // - The first if always leads to fail1, and away from the second if, and so we + // only have a single CmpI in the graph after parsing. + // - The first if always leads towards the second if, and away from fail1. And + // the second if always points towards fail2 and away from succ. We get an + // uncommon trap for fail1 and succ, and only the fail2 path is compiled. + // Hence, we have two CmpI, but fail1 and fail2 do not end up "in the same place". + // This makes our IR rule quite weak, sadly. We could make the IR rules stronger, + // but we would need to control warmup, and generate corresponding inputs that + // ensure the right paths are compiled or not compiled. + if (withWarmup) { + cmpIParse = "<= 2"; cmpUParse = "= 0"; cmpIFinal = "<= 2"; cmpUFinal = "< 2"; + comment = "with warmup: unstable-if makes precise counting hard."; + } + + return scope( + let("IP", cmpIParse), + let("UP", cmpUParse), + let("IF", cmpIFinal), + let("UF", cmpUFinal), + let("comment", comment), + """ + // #comment + @IR(counts = {IRNode.CMP_I, "#IP", IRNode.CMP_U, "#UP"}, phase = CompilePhase.AFTER_PARSING) + @IR(counts = {IRNode.CMP_I, "#IF", IRNode.CMP_U, "#UF"}) + """ + ); + }); + } + + @Override + public Template.ZeroArgs getInputTemplate() { + return Template.make(() -> scope( + let("lo", lo), + let("hi", hi), + """ + Random r = Utils.getRandomInstance(); + RestrictableGenerator gen = Generators.G.ints(); + int a = gen.next(); + int b = gen.next(); + """, + switch (RANDOM.nextInt(9)) { + // Random values + case 0 -> "int n = gen.next();\n"; + // Fuzz around specific values + case 1 -> "int n = r.nextInt(10) - 5 + #lo;\n"; + case 2 -> "int n = r.nextInt(10) - 5 + #hi;\n"; + case 3 -> "int n = r.nextInt(10) - 5 + (r.nextBoolean() ? #lo : #hi);\n"; + case 4 -> "int n = r.nextInt(10) - 5 + Integer.MAX_VALUE;\n"; + // Only very low or very high values, or in the middle + case 5 -> "int n = r.nextInt(10) - 10 + Integer.MAX_VALUE;\n"; + case 6 -> "int n = r.nextInt(10) + Integer.MIN_VALUE;\n"; + case 7 -> "int n = r.nextInt(10) - 5 + #lo/2 + #hi/2;\n"; + // Always the same constant + default -> "int n = " + INT_GEN.next() + ";\n"; + } + )); + }; + } + + // switch cases can also be implemented with range checks using + // constants, and then we can optimize 2 CmpI with a single CmpU, + // at least in some cases. + static class TestMethodGeneratorSwitch implements TestMethodGenerator { + Set cases = new HashSet<>(); + { // instance initializer + int n = RANDOM.nextInt(1, 20); + for (int i = 0; i < n; i++) { + cases.add((short)(int)INT_GEN.next()); + } + } + + private final Template.OneArg testTemplate = Template.make("methodName", (String methodName) -> scope( + """ + static boolean #methodName(int n, int a, int b) { + switch((short)n) { + """, + cases.stream().map(i -> scope( + let("i", i), + """ + case (short)#i: + """ + )).toList(), + """ + return true; + default: + return false; + } + } + """ + )); + + public Template.OneArg getTestTemplate() { return testTemplate; } + } + + // If arr.length is in the second check, the null-check for arr + // is located between the two checks. + // I'm not adding any IR rules here, just checking for correctness. + static class TestMethodGeneratorArrLength implements TestMethodGenerator { + private final int n_hi = INT_GEN.next(); + private final int n_lo = INT_GEN.next(); + private final int a_hi = INT_GEN.next(); + private final int a_lo = INT_GEN.next(); + private final int size = INT_GEN.restricted(0, 100_000).next(); + + // Get checks like: n < a || n >= arr.length + private final Comparison c_lo = new Comparison("n", Comparator.random(), "a").permuteRandom(); + private final Comparison c_hi = new Comparison("n", Comparator.random(), "arr.length").permuteRandom(); + private final boolean swap = RANDOM.nextBoolean(); + private final Comparison c1Permuted = (swap ? c_lo : c_hi).permuteRandom(); + private final Comparison c2Permuted = (swap ? c_hi : c_lo).permuteRandom(); + // n > lo && n < hi -> check for inside range + // n <= lo || n >= hi -> chedk for outside range + private final boolean withAnd = RANDOM.nextBoolean(); + private final String operator = withAnd ? "&&" : "||"; + private final Comparison c1 = withAnd ? c1Permuted : c1Permuted.negateCmp(); + private final Comparison c2 = withAnd ? c2Permuted : c2Permuted.negateCmp(); + + private final Template.OneArg testTemplate = Template.make("methodName", (String methodName) -> scope( + let("n_hi", n_hi), + let("n_lo", n_lo), + let("a_hi", a_hi), + let("a_lo", a_lo), + let("size", size), + let("c1", c1), + let("c2", c2), + let("op", operator), + """ + static boolean #methodName(int n, int a, int b) { + int[] arr = $arr; + n = Math.min(#n_hi, Math.max(#n_lo, n)); + a = Math.min(#a_hi, Math.max(#a_lo, a)); + if (#c1 #op #c2) { + return true; + } + return false; + } + static int[] $arr = new int[#size]; + """ + )); + + public Template.OneArg getTestTemplate() { return testTemplate; } + } + + public static TemplateToken generateTest(int warmup) { + TestMethodGenerator tg = switch(RANDOM.nextInt(6)) { + case 0 -> new TestMethodGeneratorConst(); + case 1 -> new TestMethodGeneratorWithIf(); + case 2 -> new TestMethodGeneratorRanges(); + case 3 -> new TestMethodGeneratorConstIR(); + case 4 -> new TestMethodGeneratorSwitch(); + case 5 -> new TestMethodGeneratorArrLength(); + default -> throw new RuntimeException("not expected"); + }; + Template.ZeroArgs testInputTemplate = tg.getInputTemplate(); + Template.OneArg testMethodTemplate = tg.getTestTemplate(); + Template.ZeroArgs testIRTemplate = tg.getIRTemplate(warmup >= 10_000); + + var testTemplate = Template.make(() -> scope( + let("warmup", warmup / 100), + """ + // --- $test start --- + @Run(test = "$test") + @Warmup(#warmup) + public static void $run() { + for (int i = 0; i < 100; i++) { + // Generate random values for n, a, b. + """, + testInputTemplate.asToken(), + """ + + // Run test and compare with interpreter results. + var result = $test(n, a, b); + var expected = $reference(n, a, b); + if (result != expected) { + throw new RuntimeException("wrong result: " + result + " vs " + expected + + "\\nn: " + n + + "\\na: " + a + + "\\nb: " + b); + } + } + } + + @Test + """, + testIRTemplate.asToken(), + testMethodTemplate.asToken($("test")), + """ + + @DontCompile + """, + testMethodTemplate.asToken($("reference")), + """ + // --- $test end --- + """ + )); + return testTemplate.asToken(); + } +}