8347555: [REDO] C2: implement optimization for series of Add of unique value

Reviewed-by: roland, epeter
This commit is contained in:
Kangcheng Xu 2025-10-10 14:04:51 +00:00 committed by Roland Westrelin
parent 5594d6bc88
commit f6d77cb332
6 changed files with 875 additions and 1 deletions

View File

@ -396,9 +396,182 @@ Node* AddNode::IdealIL(PhaseGVN* phase, bool can_reshape, BasicType bt) {
}
}
// Collapse addition of the same terms into multiplications.
Node* collapsed = Ideal_collapse_variable_times_con(phase, bt);
if (collapsed != nullptr) {
return collapsed; // Skip AddNode::Ideal() since it may now be a multiplication node.
}
return AddNode::Ideal(phase, can_reshape);
}
// Try to collapse addition of the same terms into a single multiplication. On success, a new MulNode is returned.
// Examples of this conversion includes:
// - a + a + ... + a => CON*a
// - (a * CON) + a => (CON + 1) * a
// - a + (a * CON) => (CON + 1) * a
//
// We perform such conversions incrementally during IGVN by transforming left most nodes first and work up to the root
// of the expression. In other words, we convert, at each iteration:
// a + a + a + ... + a
// => 2*a + a + ... + a
// => 3*a + ... + a
// => n*a
//
// Due to the iterative nature of IGVN, MulNode transformed from first few AddNode terms may be further transformed into
// power-of-2 pattern. (e.g., 2 * a => a << 1, 3 * a => (a << 2) + a). We can't guarantee we'll always pick up
// transformed power-of-2 patterns when term `a` is complex.
//
// Note this also converts, for example, original expression `(a*3) + a` into `4*a` and `(a<<2) + a` into `5*a`. A more
// generalized pattern `(a*b) + (a*c)` into `a*(b + c)` is handled by AddNode::IdealIL().
Node* AddNode::Ideal_collapse_variable_times_con(PhaseGVN* phase, BasicType bt) {
// We need to make sure that the current AddNode is not part of a MulNode that has already been optimized to a
// power-of-2 addition (e.g., 3 * a => (a << 2) + a). Without this check, GVN would keep trying to optimize the same
// node and can't progress. For example, 3 * a => (a << 2) + a => 3 * a => (a << 2) + a => ...
if (Multiplication::find_power_of_two_addition_pattern(this, bt).is_valid()) {
return nullptr;
}
Node* lhs = in(1);
Node* rhs = in(2);
Multiplication mul = Multiplication::find_collapsible_addition_patterns(lhs, rhs, bt);
if (!mul.is_valid_with(rhs)) {
// Swap lhs and rhs then try again
mul = Multiplication::find_collapsible_addition_patterns(rhs, lhs, bt);
if (!mul.is_valid_with(lhs)) {
return nullptr;
}
}
Node* con;
if (bt == T_INT) {
con = phase->intcon(java_add(static_cast<jint>(mul.multiplier()), 1));
} else {
con = phase->longcon(java_add(mul.multiplier(), CONST64(1)));
}
return MulNode::make(con, mul.variable(), bt);
}
// Find a pattern of collapsable additions that can be converted to a multiplication.
// When matching the LHS `a * CON`, we match with best efforts by looking for the following patterns:
// - (1) Simple addition: LHS = a + a
// - (2) Simple lshift: LHS = a << CON
// - (3) Simple multiplication: LHS = CON * a
// - (4) Power-of-two addition: LHS = (a << CON1) + (a << CON2)
AddNode::Multiplication AddNode::Multiplication::find_collapsible_addition_patterns(const Node* a, const Node* pattern, BasicType bt) {
// (1) Simple addition pattern (e.g., lhs = a + a)
Multiplication mul = find_simple_addition_pattern(a, bt);
if (mul.is_valid_with(pattern)) {
return mul;
}
// (2) Simple lshift pattern (e.g., lhs = a << CON)
mul = find_simple_lshift_pattern(a, bt);
if (mul.is_valid_with(pattern)) {
return mul;
}
// (3) Simple multiplication pattern (e.g., lhs = CON * a)
mul = find_simple_multiplication_pattern(a, bt);
if (mul.is_valid_with(pattern)) {
return mul;
}
// (4) Power-of-two addition pattern (e.g., lhs = (a << CON1) + (a << CON2))
// While multiplications can be potentially optimized to power-of-2 subtractions (e.g., a * 7 => (a << 3) - a),
// (x - y) + y => x is already handled by the Identity() methods. So, we don't need to check for that pattern here.
mul = find_power_of_two_addition_pattern(a, bt);
if (mul.is_valid_with(pattern)) {
return mul;
}
// We've tried everything.
return make_invalid();
}
// Try to match `n = a + a`. On success, return a struct with `.valid = true`, `variable = a`, and `multiplier = 2`.
// The method matches `n` for pattern: a + a.
AddNode::Multiplication AddNode::Multiplication::find_simple_addition_pattern(const Node* n, BasicType bt) {
if (n->Opcode() == Op_Add(bt) && n->in(1) == n->in(2)) {
return Multiplication(n->in(1), 2);
}
return make_invalid();
}
// Try to match `n = a << CON`. On success, return a struct with `.valid = true`, `variable = a`, and
// `multiplier = 1 << CON`.
// Match `n` for pattern: a << CON.
// Note that the power-of-2 multiplication optimization could potentially convert a MulNode to this pattern.
AddNode::Multiplication AddNode::Multiplication::find_simple_lshift_pattern(const Node* n, BasicType bt) {
// Note that power-of-2 multiplication optimization could potentially convert a MulNode to this pattern
if (n->Opcode() == Op_LShift(bt) && n->in(2)->is_Con()) {
Node* con = n->in(2);
if (!con->is_top()) {
return Multiplication(n->in(1), java_shift_left(1, con->get_int(), bt));
}
}
return make_invalid();
}
// Try to match `n = CON * a`. On success, return a struct with `.valid = true`, `variable = a`, and `multiplier = CON`.
// Match `n` for patterns: CON * a
// Note that `CON` will always be the second input node of a Mul node canonicalized by Ideal(). If this is not the case,
// `n` has not been processed by iGVN. So we skip the optimization for the current add node and wait for to be added to
// the queue again.
AddNode::Multiplication AddNode::Multiplication::find_simple_multiplication_pattern(const Node* n, BasicType bt) {
if (n->Opcode() == Op_Mul(bt) && n->in(2)->is_Con()) {
Node* con = n->in(2);
Node* base = n->in(1);
if (!con->is_top()) {
return Multiplication(base, con->get_integer_as_long(bt));
}
}
return make_invalid();
}
// Try to match `n = (a << CON1) + (a << CON2)`. On success, return a struct with `.valid = true`, `variable = a`, and
// `multiplier = (1 << CON1) + (1 << CON2)`.
// Match `n` for patterns:
// - (1) (a << CON) + (a << CON)
// - (2) (a << CON) + a
// - (3) a + (a << CON)
// - (4) a + a
// Note that one or both of the term of the addition could simply be `a` (i.e., a << 0) as in pattern (4).
AddNode::Multiplication AddNode::Multiplication::find_power_of_two_addition_pattern(const Node* n, BasicType bt) {
if (n->Opcode() == Op_Add(bt) && n->in(1) != n->in(2)) {
const Multiplication lhs = find_simple_lshift_pattern(n->in(1), bt);
const Multiplication rhs = find_simple_lshift_pattern(n->in(2), bt);
// Pattern (1)
{
const Multiplication res = lhs.add(rhs);
if (res.is_valid()) {
return res;
}
}
// Pattern (2)
if (lhs.is_valid_with(n->in(2))) {
return Multiplication(lhs.variable(), java_add(lhs.multiplier(), CONST64(1)));
}
// Pattern (3)
if (rhs.is_valid_with(n->in(1))) {
return Multiplication(rhs.variable(), java_add(rhs.multiplier(), CONST64(1)));
}
// Pattern (4), which is equivalent to a simple addition pattern
return find_simple_addition_pattern(n, bt);
}
return make_invalid();
}
Node* AddINode::Ideal(PhaseGVN* phase, bool can_reshape) {
Node* in1 = in(1);

View File

@ -42,7 +42,51 @@ typedef const Pair<Node*, jint> ConstAddOperands;
// by virtual functions.
class AddNode : public Node {
virtual uint hash() const;
public:
class Multiplication {
bool _is_valid = false;
Node* _variable = nullptr;
jlong _multiplier = 0;
private:
Multiplication() {}
public:
Multiplication(Node* variable, jlong multiplier) :
_is_valid(true),
_variable(variable),
_multiplier(multiplier) {}
static Multiplication make_invalid() {
static Multiplication invalid = Multiplication();
return invalid;
}
static Multiplication find_collapsible_addition_patterns(const Node* a, const Node* pattern, BasicType bt);
static Multiplication find_simple_addition_pattern(const Node* n, BasicType bt);
static Multiplication find_simple_lshift_pattern(const Node* n, BasicType bt);
static Multiplication find_simple_multiplication_pattern(const Node* n, BasicType bt);
static Multiplication find_power_of_two_addition_pattern(const Node* n, BasicType bt);
Multiplication add(const Multiplication rhs) const {
if (is_valid_with(rhs.variable()) && rhs.is_valid_with(variable())) {
return Multiplication(variable(), java_add(multiplier(), rhs.multiplier()));
}
return make_invalid();
}
bool is_valid() const { return _is_valid; }
bool is_valid_with(const Node* variable) const {
return _is_valid && this->_variable == variable;
}
Node* variable() const { return _variable; }
jlong multiplier() const { return _multiplier; }
};
public:
AddNode( Node *in1, Node *in2 ) : Node(nullptr,in1,in2) {
init_class_id(Class_Add);
}
@ -55,6 +99,7 @@ public:
// and flatten expressions (so that 1+x+2 becomes x+3).
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
Node* IdealIL(PhaseGVN* phase, bool can_reshape, BasicType bt);
Node* Ideal_collapse_variable_times_con(PhaseGVN* phase, BasicType bt);
// Compute a new Type for this node. Basically we just do the pre-check,
// then call the virtual add() to set the type.

View File

@ -1253,6 +1253,24 @@ JAVA_INTEGER_SHIFT_OP(>>, java_shift_right_unsigned, jlong, julong)
#undef JAVA_INTEGER_SHIFT_OP
// Some convenient bit shift operations that accepts a BasicType as the last
// argument. These avoid potential mistakes with overloaded functions only
// distinguished by lhs argument type.
#define JAVA_INTEGER_SHIFT_BASIC_TYPE(FUNC) \
inline jlong FUNC(jlong lhs, jint rhs, BasicType bt) { \
if (bt == T_INT) { \
return FUNC((jint) lhs, rhs); \
} \
assert(bt == T_LONG, "unsupported basic type"); \
return FUNC(lhs, rhs); \
}
JAVA_INTEGER_SHIFT_BASIC_TYPE(java_shift_left)
JAVA_INTEGER_SHIFT_BASIC_TYPE(java_shift_right)
JAVA_INTEGER_SHIFT_BASIC_TYPE(java_shift_right_unsigned)
#undef JAVA_INTERGER_SHIFT_BASIC_TYPE
//----------------------------------------------------------------------------------------------------
// The goal of this code is to provide saturating operations for int/uint.
// Checks overflow conditions and saturates the result to min_jint/max_jint.

View File

@ -192,6 +192,7 @@ TEST(TestJavaArithmetic, shift_left_jint) {
const volatile ShiftOpJintData* data = asl_jint_data;
for (size_t i = 0; i < ARRAY_SIZE(asl_jint_data); ++i) {
ASSERT_EQ(data[i].r, java_shift_left(data[i].x, data[i].shift));
ASSERT_EQ(data[i].r, java_shift_left(data[i].x, data[i].shift, T_INT));
}
}
@ -199,6 +200,7 @@ TEST(TestJavaArithmetic, shift_left_jlong) {
const volatile ShiftOpJlongData* data = asl_jlong_data;
for (size_t i = 0; i < ARRAY_SIZE(asl_jlong_data); ++i) {
ASSERT_EQ(data[i].r, java_shift_left(data[i].x, data[i].shift));
ASSERT_EQ(data[i].r, java_shift_left(data[i].x, data[i].shift, T_LONG));
}
}
@ -262,6 +264,7 @@ TEST(TestJavaArithmetic, shift_right_jint) {
const volatile ShiftOpJintData* data = asr_jint_data;
for (size_t i = 0; i < ARRAY_SIZE(asr_jint_data); ++i) {
ASSERT_EQ(data[i].r, java_shift_right(data[i].x, data[i].shift));
ASSERT_EQ(data[i].r, java_shift_right(data[i].x, data[i].shift, T_INT));
}
}
@ -269,6 +272,7 @@ TEST(TestJavaArithmetic, shift_right_jlong) {
const volatile ShiftOpJlongData* data = asr_jlong_data;
for (size_t i = 0; i < ARRAY_SIZE(asr_jlong_data); ++i) {
ASSERT_EQ(data[i].r, java_shift_right(data[i].x, data[i].shift));
ASSERT_EQ(data[i].r, java_shift_right(data[i].x, data[i].shift, T_LONG));
}
}
@ -334,6 +338,7 @@ TEST(TestJavaArithmetic, shift_right_unsigned_jint) {
const volatile ShiftOpJintData* data = lsr_jint_data;
for (size_t i = 0; i < ARRAY_SIZE(lsr_jint_data); ++i) {
ASSERT_EQ(data[i].r, java_shift_right_unsigned(data[i].x, data[i].shift));
ASSERT_EQ(data[i].r, java_shift_right_unsigned(data[i].x, data[i].shift, T_INT));
}
}
@ -341,5 +346,6 @@ TEST(TestJavaArithmetic, shift_right_unsigned_jlong) {
const volatile ShiftOpJlongData* data = lsr_jlong_data;
for (size_t i = 0; i < ARRAY_SIZE(lsr_jlong_data); ++i) {
ASSERT_EQ(data[i].r, java_shift_right_unsigned(data[i].x, data[i].shift));
ASSERT_EQ(data[i].r, java_shift_right_unsigned(data[i].x, data[i].shift, T_LONG));
}
}

View File

@ -0,0 +1,430 @@
/*
* Copyright (c) 2025 Red Hat and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package compiler.c2.gvn;
import compiler.lib.generators.Generators;
import compiler.lib.generators.RestrictableGenerator;
import compiler.lib.ir_framework.Test;
import compiler.lib.ir_framework.*;
import jdk.test.lib.Asserts;
import jdk.test.lib.Utils;
import java.util.Random;
/*
* @test
* @bug 8325495 8347555
* @summary C2 should optimize addition of the same terms by collapsing them into one multiplication.
* @library /test/lib /
* @run driver compiler.c2.gvn.TestCollapsingSameTermAdditions
*/
public class TestCollapsingSameTermAdditions {
private static final RestrictableGenerator<Integer> GEN_INT = Generators.G.ints();
private static final RestrictableGenerator<Long> GEN_LONG = Generators.G.longs();
public static void main(String[] args) {
TestFramework.run();
}
@Run(test = {
"addTo2",
"addTo3",
"addTo4",
"shiftAndAddTo4",
"mulAndAddTo4",
"addTo5",
"addTo6",
"addTo7",
"addTo8",
"addTo16",
"addAndShiftTo16",
"addTo42",
"mulAndAddTo42",
"mulAndAddToMax",
"mulAndAddToOverflow",
"mulAndAddToZero",
"mulAndAddToMinus1",
"mulAndAddToMinus42"
})
private void runIntTests() {
for (int a : new int[] { 0, 1, Integer.MIN_VALUE, Integer.MAX_VALUE, GEN_INT.next() }) {
Asserts.assertEQ(a * 2, addTo2(a));
Asserts.assertEQ(a * 3, addTo3(a));
Asserts.assertEQ(a * 4, addTo4(a));
Asserts.assertEQ(a * 4, shiftAndAddTo4(a));
Asserts.assertEQ(a * 4, mulAndAddTo4(a));
Asserts.assertEQ(a * 5, addTo5(a));
Asserts.assertEQ(a * 6, addTo6(a));
Asserts.assertEQ(a * 7, addTo7(a));
Asserts.assertEQ(a * 8, addTo8(a));
Asserts.assertEQ(a * 16, addTo16(a));
Asserts.assertEQ(a * 16, addAndShiftTo16(a));
Asserts.assertEQ(a * 42, addTo42(a));
Asserts.assertEQ(a * 42, mulAndAddTo42(a));
Asserts.assertEQ(a * Integer.MAX_VALUE, mulAndAddToMax(a));
Asserts.assertEQ(a * Integer.MIN_VALUE, mulAndAddToOverflow(a));
Asserts.assertEQ(0, mulAndAddToZero(a));
Asserts.assertEQ(a * -1, mulAndAddToMinus1(a));
Asserts.assertEQ(a * -42, mulAndAddToMinus42(a));
}
}
@Run(test = {
"mulAndAddToIntOverflowL",
"mulAndAddToMaxL",
"mulAndAddToOverflowL"
})
private void runLongTests() {
for (long a : new long[] { 0, 1, Long.MIN_VALUE, Long.MAX_VALUE, GEN_LONG.next() }) {
Asserts.assertEQ(a * (Integer.MAX_VALUE + 1L), mulAndAddToIntOverflowL(a));
Asserts.assertEQ(a * Long.MAX_VALUE, mulAndAddToMaxL(a));
Asserts.assertEQ(a * Long.MIN_VALUE, mulAndAddToOverflowL(a));
}
}
@Run(test = {
"bitShiftToOverflow",
"bitShiftToOverflowL"
})
private void runBitShiftTests() {
Asserts.assertEQ(95, bitShiftToOverflow());
Asserts.assertEQ(191L, bitShiftToOverflowL());
}
// ----- integer tests -----
@Test
@IR(counts = { IRNode.ADD_I, "1" })
@IR(failOn = IRNode.LSHIFT_I)
private static int addTo2(int a) {
return a + a; // Simple additions like a + a should be kept as-is
}
@Test
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" })
private static int addTo3(int a) {
return a + a + a; // a*3 => (a<<1) + a
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1" })
private static int addTo4(int a) {
return a + a + a + a; // a*4 => a<<2
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1" })
private static int shiftAndAddTo4(int a) {
return (a << 1) + a + a; // a*2 + a + a => a*3 + a => a*4 => a<<2
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1" })
private static int mulAndAddTo4(int a) {
return a * 3 + a; // a*4 => a<<2
}
@Test
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" })
private static int addTo5(int a) {
return a + a + a + a + a; // a*5 => (a<<2) + a
}
@Test
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "2" })
private static int addTo6(int a) {
return a + a + a + a + a + a; // a*6 => (a<<1) + (a<<2)
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
private static int addTo7(int a) {
return a + a + a + a + a + a + a; // a*7 => (a<<3) - a
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1" })
private static int addTo8(int a) {
return a + a + a + a + a + a + a + a; // a*8 => a<<3
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1" })
private static int addTo16(int a) {
return a + a + a + a + a + a + a + a + a + a
+ a + a + a + a + a + a; // a*16 => a<<4
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1" })
private static int addAndShiftTo16(int a) {
return (a + a) << 3; // a<<(3 + 1) => a<<4
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.MUL_I, "1" })
private static int addTo42(int a) {
return a + a + a + a + a + a + a + a + a + a
+ a + a + a + a + a + a + a + a + a + a
+ a + a + a + a + a + a + a + a + a + a
+ a + a + a + a + a + a + a + a + a + a
+ a + a; // a*42
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.MUL_I, "1" })
private static int mulAndAddTo42(int a) {
return a * 40 + a + a; // a*41 + a => a*42
}
private static final int INT_MAX_MINUS_ONE = Integer.MAX_VALUE - 1;
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
private static int mulAndAddToMax(int a) {
return a * INT_MAX_MINUS_ONE + a; // a*MAX => a*(MIN-1) => a*MIN - a => (a<<31) - a
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1" })
private static int mulAndAddToOverflow(int a) {
return a * Integer.MAX_VALUE + a; // a*(MAX+1) => a*(MIN) => a<<31
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.CON_I, "1" })
private static int mulAndAddToZero(int a) {
return a * -1 + a; // 0
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
private static int mulAndAddToMinus1(int a) {
return a * -2 + a; // a*-1 => a - (a<<1)
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.MUL_I, "1" })
private static int mulAndAddToMinus42(int a) {
return a * -43 + a; // a*-42
}
// --- long tests ---
@Test
@IR(failOn = IRNode.ADD_L)
@IR(counts = { IRNode.LSHIFT_L, "1" })
private static long mulAndAddToIntOverflowL(long a) {
return a * Integer.MAX_VALUE + a; // a*(INT_MAX+1)
}
private static final long LONG_MAX_MINUS_ONE = Long.MAX_VALUE - 1;
@Test
@IR(failOn = IRNode.ADD_L)
@IR(counts = { IRNode.LSHIFT_L, "1", IRNode.SUB_L, "1" })
private static long mulAndAddToMaxL(long a) {
return a * LONG_MAX_MINUS_ONE + a; // a*MAX => a*(MIN-1) => a*MIN - 1 => (a<<63) - 1
}
@Test
@IR(failOn = IRNode.ADD_L)
@IR(counts = { IRNode.LSHIFT_L, "1" })
private static long mulAndAddToOverflowL(long a) {
return a * Long.MAX_VALUE + a; // a*(MAX+1) => a*(MIN) => a<<63
}
// --- bit shift tests ---
@Test
@IR(failOn = {IRNode.ADD_I, IRNode.LSHIFT_I})
private static int bitShiftToOverflow() {
int i, x = 0;
for (i = 0; i < 32; i++) {
x = i;
}
// x = 31 (phi), i = 32 (phi + 1)
return i + (x << i) + i; // Expects 32 + 31 + 32 = 95
}
@Test
@IR(failOn = {IRNode.ADD_L, IRNode.LSHIFT_L})
private static long bitShiftToOverflowL() {
int i, x = 0;
for (i = 0; i < 64; i++) {
x = i;
}
// x = 63 (phi), i = 64 (phi + 1)
return i + (x << i) + i; // Expects 64 + 63 + 64 = 191
}
// --- random tests ---
private static final int CON1_I, CON2_I, CON3_I, CON4_I;
private static final long CON1_L, CON2_L, CON3_L, CON4_L;
static {
CON1_I = GEN_INT.next();
CON2_I = GEN_INT.next();
CON3_I = GEN_INT.next();
CON4_I = GEN_INT.next();
CON1_L = GEN_LONG.next();
CON2_L = GEN_LONG.next();
CON3_L = GEN_LONG.next();
CON4_L = GEN_LONG.next();
}
@Run(test = {
"randomPowerOfTwoAddition",
"randomPowerOfTwoAdditionL"
})
private void runRandomPowerOfTwoAddition() {
for (int a : new int[] { 0, 1, Integer.MIN_VALUE, Integer.MAX_VALUE, GEN_INT.next() }) {
Asserts.assertEQ(a * (CON1_I + CON2_I + CON3_I + CON4_I), randomPowerOfTwoAddition(a));
}
for (long a : new long[] { 0, 1, Long.MIN_VALUE, Long.MAX_VALUE, GEN_LONG.next() }) {
Asserts.assertEQ(a * (CON1_L + CON2_L + CON3_L + CON4_L), randomPowerOfTwoAdditionL(a));
}
}
// We can't do IR verification but only check for correctness for a better confidence.
@Test
private static int randomPowerOfTwoAddition(int a) {
return a * CON1_I + a * CON2_I + a * CON3_I + a * CON4_I;
}
@Test
private static long randomPowerOfTwoAdditionL(long a) {
return a * CON1_L + a * CON2_L + a * CON3_L + a * CON4_L;
}
// Patterns that are originally cannot be recognized due to their right precedence making it difficult without
// recursion, but some are made possible with swapping lhs and rhs.
@Run(test = {
"rightPrecedence",
"rightPrecedenceL",
"rightPrecedenceShift",
"rightPrecedenceShiftL",
})
private void runLhsRhsSwaps() {
for (int a : new int[] { 0, 1, Integer.MIN_VALUE, Integer.MAX_VALUE, GEN_INT.next() }) {
Asserts.assertEQ(a * 3, rightPrecedence(a));
Asserts.assertEQ(a * 4, rightPrecedenceShift(a));
}
for (long a : new long[] { 0, 1, Long.MIN_VALUE, Long.MAX_VALUE, GEN_LONG.next() }) {
Asserts.assertEQ(a * 3, rightPrecedenceL(a));
Asserts.assertEQ(a * 4, rightPrecedenceShiftL(a));
}
}
@Test
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" })
private static int rightPrecedence(int a) {
return a + (a + a);
}
@Test
@IR(counts = { IRNode.ADD_L, "1", IRNode.LSHIFT_L, "1" })
private static long rightPrecedenceL(long a) {
return a + (a + a);
}
@Test
@IR(failOn = IRNode.ADD_I)
@IR(counts = { IRNode.LSHIFT_I, "1" })
private static int rightPrecedenceShift(int a) {
return a + (a << 1) + a; // a + a*2 + a => a*2 + a + a => a*3 + a => a*4 => a<<2
}
@Test
@IR(failOn = IRNode.ADD_L)
@IR(counts = { IRNode.LSHIFT_L, "1" })
private static long rightPrecedenceShiftL(long a) {
return a + (a << 1) + a; // a + a*2 + a => a*2 + a + a => a*3 + a => a*4 => a<<2
}
// JDK-8347555 only aims to cover cases minimally needed for patterns a + a + ... + a => n*a. However, some patterns
// like CON * a + a => (CON + 1) * a are considered unintended side-effects due to the way pattern matching is
// implemented.
//
// The followings are patterns that could be, mathematically speaking, optimized, but not implemented at this stage.
// These tests are to be updated if they are addressed in the future.
@Test
@IR(counts = { IRNode.ADD_I, "2", IRNode.LSHIFT_I, "2" })
@Arguments(values = { Argument.RANDOM_EACH })
private static int complexShiftPattern(int a) {
return a + (a << 1) + (a << 2); // This could've been: a + a*2 + a*4 => a*7
}
@Test
@IR(counts = { IRNode.ADD_I, "2" }) // b = a + a, c = b + b
@Arguments(values = { Argument.RANDOM_EACH })
private static int nestedAddPattern(int a) {
return (a + a) + (a + a); // This could've been: 2*a + 2*a => 4*a
}
@Test
@IR(counts = { IRNode.ADD_I, "3", IRNode.LSHIFT_I, "1" })
@Arguments(values = { Argument.RANDOM_EACH })
private static int complexPrecedence(int a) {
return a + a + ((a + a) + a); // This could've been: 2*a + (2*a + a) => 2*a + 3*a => 5*a
}
@Test
@IR(counts = { IRNode.ADD_L, "2", IRNode.LSHIFT_L, "2" })
@Arguments(values = { Argument.RANDOM_EACH })
private static long complexShiftPatternL(long a) {
return a + (a << 1) + (a << 2); // This could've been: a + a*2 + a*4 => a*7
}
@Test
@IR(counts = { IRNode.ADD_L, "2" }) // b = a + a, c = b + b
@Arguments(values = { Argument.RANDOM_EACH })
private static long nestedAddPatternL(long a) {
return (a + a) + (a + a); // This could've been: 2*a + 2*a => 4*a
}
@Test
@IR(counts = { IRNode.ADD_L, "3", IRNode.LSHIFT_L, "1" })
@Arguments(values = { Argument.RANDOM_EACH })
private static long complexPrecedenceL(long a) {
return a + a + ((a + a) + a); // This could've been: 2*a + (2*a + a) => 2*a + 3*a => 5*a
}
}

View File

@ -0,0 +1,202 @@
/*
* Copyright (c) 2025, Red Hat and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.vm.compiler;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
import java.util.concurrent.TimeUnit;
/**
* Tests speed of adding a series of additions of the same operand.
*/
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Warmup(iterations = 4, time = 2, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 4, time = 2, timeUnit = TimeUnit.SECONDS)
@Fork(value = 3)
public class SerialAdditions {
private int a = 0xBADB0BA;
private long b = 0x900dba51l;
@Benchmark
public int addIntsTo02() {
return a + a; // baseline, still a + a
}
@Benchmark
public int addIntsTo04() {
return a + a + a + a; // a*4 => a<<2
}
@Benchmark
public int addIntsTo05() {
return a + a + a + a + a; // a*5 => (a<<2) + a
}
@Benchmark
public int addIntsTo06() {
return a + a + a + a + a + a; // a*6 => (a<<1) + (a<<2)
}
@Benchmark
public int addIntsTo08() {
return a + a + a + a + a + a + a + a; // a*8 => a<<3
}
@Benchmark
public int addIntsTo16() {
return a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a; // a*16 => a<<4
}
@Benchmark
public int addIntsTo23() {
return a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a; // a*23
}
@Benchmark
public int addIntsTo32() {
return a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a; // a*32 => a<<5
}
@Benchmark
public int addIntsTo42() {
return a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a; // a*42
}
@Benchmark
public int addIntsTo64() {
return a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a + a + a + a + a + a + a //
+ a + a + a + a; // 64 * a => a << 6
}
@Benchmark
public void addIntsMixed(Blackhole blackhole) {
blackhole.consume(addIntsTo02());
blackhole.consume(addIntsTo04());
blackhole.consume(addIntsTo05());
blackhole.consume(addIntsTo06());
blackhole.consume(addIntsTo08());
blackhole.consume(addIntsTo16());
blackhole.consume(addIntsTo23());
blackhole.consume(addIntsTo32());
blackhole.consume(addIntsTo42());
blackhole.consume(addIntsTo64());
}
@Benchmark
public long addLongsTo02() {
return b + b; // baseline, still a + a
}
@Benchmark
public long addLongsTo04() {
return b + b + b + b; // a*4 => a<<2
}
@Benchmark
public long addLongsTo05() {
return b + b + b + b + b; // a*5 => (a<<2) + a
}
@Benchmark
public long addLongsTo06() {
return b + b + b + b + b + b; // a*6 => (a<<1) + (a<<2)
}
@Benchmark
public long addLongsTo08() {
return b + b + b + b + b + b + b + b; // a*8 => a<<3
}
@Benchmark
public long addLongsTo16() {
return b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b; // a*16 => a<<4
}
@Benchmark
public long addLongsTo23() {
return b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b; // a*23
}
@Benchmark
public long addLongsTo32() {
return b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b; // a*32 => a<<5
}
@Benchmark
public long addLongsTo42() {
return b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b; // a*42
}
@Benchmark
public long addLongsTo64() {
return b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b + b + b + b + b + b + b //
+ b + b + b + b; // 64 * a => a << 6
}
@Benchmark
public void addLongsMixed(Blackhole blackhole) {
blackhole.consume(addLongsTo02());
blackhole.consume(addLongsTo04());
blackhole.consume(addLongsTo05());
blackhole.consume(addLongsTo06());
blackhole.consume(addLongsTo08());
blackhole.consume(addLongsTo16());
blackhole.consume(addLongsTo23());
blackhole.consume(addLongsTo32());
blackhole.consume(addLongsTo42());
blackhole.consume(addLongsTo64());
}
}