mirror of
https://github.com/openjdk/jdk.git
synced 2026-01-28 12:09:14 +00:00
8347555: [REDO] C2: implement optimization for series of Add of unique value
Reviewed-by: roland, epeter
This commit is contained in:
parent
5594d6bc88
commit
f6d77cb332
@ -396,9 +396,182 @@ Node* AddNode::IdealIL(PhaseGVN* phase, bool can_reshape, BasicType bt) {
|
||||
}
|
||||
}
|
||||
|
||||
// Collapse addition of the same terms into multiplications.
|
||||
Node* collapsed = Ideal_collapse_variable_times_con(phase, bt);
|
||||
if (collapsed != nullptr) {
|
||||
return collapsed; // Skip AddNode::Ideal() since it may now be a multiplication node.
|
||||
}
|
||||
|
||||
return AddNode::Ideal(phase, can_reshape);
|
||||
}
|
||||
|
||||
// Try to collapse addition of the same terms into a single multiplication. On success, a new MulNode is returned.
|
||||
// Examples of this conversion includes:
|
||||
// - a + a + ... + a => CON*a
|
||||
// - (a * CON) + a => (CON + 1) * a
|
||||
// - a + (a * CON) => (CON + 1) * a
|
||||
//
|
||||
// We perform such conversions incrementally during IGVN by transforming left most nodes first and work up to the root
|
||||
// of the expression. In other words, we convert, at each iteration:
|
||||
// a + a + a + ... + a
|
||||
// => 2*a + a + ... + a
|
||||
// => 3*a + ... + a
|
||||
// => n*a
|
||||
//
|
||||
// Due to the iterative nature of IGVN, MulNode transformed from first few AddNode terms may be further transformed into
|
||||
// power-of-2 pattern. (e.g., 2 * a => a << 1, 3 * a => (a << 2) + a). We can't guarantee we'll always pick up
|
||||
// transformed power-of-2 patterns when term `a` is complex.
|
||||
//
|
||||
// Note this also converts, for example, original expression `(a*3) + a` into `4*a` and `(a<<2) + a` into `5*a`. A more
|
||||
// generalized pattern `(a*b) + (a*c)` into `a*(b + c)` is handled by AddNode::IdealIL().
|
||||
Node* AddNode::Ideal_collapse_variable_times_con(PhaseGVN* phase, BasicType bt) {
|
||||
// We need to make sure that the current AddNode is not part of a MulNode that has already been optimized to a
|
||||
// power-of-2 addition (e.g., 3 * a => (a << 2) + a). Without this check, GVN would keep trying to optimize the same
|
||||
// node and can't progress. For example, 3 * a => (a << 2) + a => 3 * a => (a << 2) + a => ...
|
||||
if (Multiplication::find_power_of_two_addition_pattern(this, bt).is_valid()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Node* lhs = in(1);
|
||||
Node* rhs = in(2);
|
||||
|
||||
Multiplication mul = Multiplication::find_collapsible_addition_patterns(lhs, rhs, bt);
|
||||
if (!mul.is_valid_with(rhs)) {
|
||||
// Swap lhs and rhs then try again
|
||||
mul = Multiplication::find_collapsible_addition_patterns(rhs, lhs, bt);
|
||||
if (!mul.is_valid_with(lhs)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Node* con;
|
||||
if (bt == T_INT) {
|
||||
con = phase->intcon(java_add(static_cast<jint>(mul.multiplier()), 1));
|
||||
} else {
|
||||
con = phase->longcon(java_add(mul.multiplier(), CONST64(1)));
|
||||
}
|
||||
|
||||
return MulNode::make(con, mul.variable(), bt);
|
||||
}
|
||||
|
||||
// Find a pattern of collapsable additions that can be converted to a multiplication.
|
||||
// When matching the LHS `a * CON`, we match with best efforts by looking for the following patterns:
|
||||
// - (1) Simple addition: LHS = a + a
|
||||
// - (2) Simple lshift: LHS = a << CON
|
||||
// - (3) Simple multiplication: LHS = CON * a
|
||||
// - (4) Power-of-two addition: LHS = (a << CON1) + (a << CON2)
|
||||
AddNode::Multiplication AddNode::Multiplication::find_collapsible_addition_patterns(const Node* a, const Node* pattern, BasicType bt) {
|
||||
// (1) Simple addition pattern (e.g., lhs = a + a)
|
||||
Multiplication mul = find_simple_addition_pattern(a, bt);
|
||||
if (mul.is_valid_with(pattern)) {
|
||||
return mul;
|
||||
}
|
||||
|
||||
// (2) Simple lshift pattern (e.g., lhs = a << CON)
|
||||
mul = find_simple_lshift_pattern(a, bt);
|
||||
if (mul.is_valid_with(pattern)) {
|
||||
return mul;
|
||||
}
|
||||
|
||||
// (3) Simple multiplication pattern (e.g., lhs = CON * a)
|
||||
mul = find_simple_multiplication_pattern(a, bt);
|
||||
if (mul.is_valid_with(pattern)) {
|
||||
return mul;
|
||||
}
|
||||
|
||||
// (4) Power-of-two addition pattern (e.g., lhs = (a << CON1) + (a << CON2))
|
||||
// While multiplications can be potentially optimized to power-of-2 subtractions (e.g., a * 7 => (a << 3) - a),
|
||||
// (x - y) + y => x is already handled by the Identity() methods. So, we don't need to check for that pattern here.
|
||||
mul = find_power_of_two_addition_pattern(a, bt);
|
||||
if (mul.is_valid_with(pattern)) {
|
||||
return mul;
|
||||
}
|
||||
|
||||
// We've tried everything.
|
||||
return make_invalid();
|
||||
}
|
||||
|
||||
// Try to match `n = a + a`. On success, return a struct with `.valid = true`, `variable = a`, and `multiplier = 2`.
|
||||
// The method matches `n` for pattern: a + a.
|
||||
AddNode::Multiplication AddNode::Multiplication::find_simple_addition_pattern(const Node* n, BasicType bt) {
|
||||
if (n->Opcode() == Op_Add(bt) && n->in(1) == n->in(2)) {
|
||||
return Multiplication(n->in(1), 2);
|
||||
}
|
||||
|
||||
return make_invalid();
|
||||
}
|
||||
|
||||
// Try to match `n = a << CON`. On success, return a struct with `.valid = true`, `variable = a`, and
|
||||
// `multiplier = 1 << CON`.
|
||||
// Match `n` for pattern: a << CON.
|
||||
// Note that the power-of-2 multiplication optimization could potentially convert a MulNode to this pattern.
|
||||
AddNode::Multiplication AddNode::Multiplication::find_simple_lshift_pattern(const Node* n, BasicType bt) {
|
||||
// Note that power-of-2 multiplication optimization could potentially convert a MulNode to this pattern
|
||||
if (n->Opcode() == Op_LShift(bt) && n->in(2)->is_Con()) {
|
||||
Node* con = n->in(2);
|
||||
if (!con->is_top()) {
|
||||
return Multiplication(n->in(1), java_shift_left(1, con->get_int(), bt));
|
||||
}
|
||||
}
|
||||
|
||||
return make_invalid();
|
||||
}
|
||||
|
||||
// Try to match `n = CON * a`. On success, return a struct with `.valid = true`, `variable = a`, and `multiplier = CON`.
|
||||
// Match `n` for patterns: CON * a
|
||||
// Note that `CON` will always be the second input node of a Mul node canonicalized by Ideal(). If this is not the case,
|
||||
// `n` has not been processed by iGVN. So we skip the optimization for the current add node and wait for to be added to
|
||||
// the queue again.
|
||||
AddNode::Multiplication AddNode::Multiplication::find_simple_multiplication_pattern(const Node* n, BasicType bt) {
|
||||
if (n->Opcode() == Op_Mul(bt) && n->in(2)->is_Con()) {
|
||||
Node* con = n->in(2);
|
||||
Node* base = n->in(1);
|
||||
|
||||
if (!con->is_top()) {
|
||||
return Multiplication(base, con->get_integer_as_long(bt));
|
||||
}
|
||||
}
|
||||
|
||||
return make_invalid();
|
||||
}
|
||||
|
||||
// Try to match `n = (a << CON1) + (a << CON2)`. On success, return a struct with `.valid = true`, `variable = a`, and
|
||||
// `multiplier = (1 << CON1) + (1 << CON2)`.
|
||||
// Match `n` for patterns:
|
||||
// - (1) (a << CON) + (a << CON)
|
||||
// - (2) (a << CON) + a
|
||||
// - (3) a + (a << CON)
|
||||
// - (4) a + a
|
||||
// Note that one or both of the term of the addition could simply be `a` (i.e., a << 0) as in pattern (4).
|
||||
AddNode::Multiplication AddNode::Multiplication::find_power_of_two_addition_pattern(const Node* n, BasicType bt) {
|
||||
if (n->Opcode() == Op_Add(bt) && n->in(1) != n->in(2)) {
|
||||
const Multiplication lhs = find_simple_lshift_pattern(n->in(1), bt);
|
||||
const Multiplication rhs = find_simple_lshift_pattern(n->in(2), bt);
|
||||
|
||||
// Pattern (1)
|
||||
{
|
||||
const Multiplication res = lhs.add(rhs);
|
||||
if (res.is_valid()) {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern (2)
|
||||
if (lhs.is_valid_with(n->in(2))) {
|
||||
return Multiplication(lhs.variable(), java_add(lhs.multiplier(), CONST64(1)));
|
||||
}
|
||||
|
||||
// Pattern (3)
|
||||
if (rhs.is_valid_with(n->in(1))) {
|
||||
return Multiplication(rhs.variable(), java_add(rhs.multiplier(), CONST64(1)));
|
||||
}
|
||||
|
||||
// Pattern (4), which is equivalent to a simple addition pattern
|
||||
return find_simple_addition_pattern(n, bt);
|
||||
}
|
||||
|
||||
return make_invalid();
|
||||
}
|
||||
|
||||
Node* AddINode::Ideal(PhaseGVN* phase, bool can_reshape) {
|
||||
Node* in1 = in(1);
|
||||
|
||||
@ -42,7 +42,51 @@ typedef const Pair<Node*, jint> ConstAddOperands;
|
||||
// by virtual functions.
|
||||
class AddNode : public Node {
|
||||
virtual uint hash() const;
|
||||
public:
|
||||
|
||||
class Multiplication {
|
||||
bool _is_valid = false;
|
||||
|
||||
Node* _variable = nullptr;
|
||||
jlong _multiplier = 0;
|
||||
|
||||
private:
|
||||
Multiplication() {}
|
||||
|
||||
public:
|
||||
Multiplication(Node* variable, jlong multiplier) :
|
||||
_is_valid(true),
|
||||
_variable(variable),
|
||||
_multiplier(multiplier) {}
|
||||
|
||||
static Multiplication make_invalid() {
|
||||
static Multiplication invalid = Multiplication();
|
||||
return invalid;
|
||||
}
|
||||
|
||||
static Multiplication find_collapsible_addition_patterns(const Node* a, const Node* pattern, BasicType bt);
|
||||
static Multiplication find_simple_addition_pattern(const Node* n, BasicType bt);
|
||||
static Multiplication find_simple_lshift_pattern(const Node* n, BasicType bt);
|
||||
static Multiplication find_simple_multiplication_pattern(const Node* n, BasicType bt);
|
||||
static Multiplication find_power_of_two_addition_pattern(const Node* n, BasicType bt);
|
||||
|
||||
Multiplication add(const Multiplication rhs) const {
|
||||
if (is_valid_with(rhs.variable()) && rhs.is_valid_with(variable())) {
|
||||
return Multiplication(variable(), java_add(multiplier(), rhs.multiplier()));
|
||||
}
|
||||
|
||||
return make_invalid();
|
||||
}
|
||||
|
||||
bool is_valid() const { return _is_valid; }
|
||||
bool is_valid_with(const Node* variable) const {
|
||||
return _is_valid && this->_variable == variable;
|
||||
}
|
||||
|
||||
Node* variable() const { return _variable; }
|
||||
jlong multiplier() const { return _multiplier; }
|
||||
};
|
||||
|
||||
public:
|
||||
AddNode( Node *in1, Node *in2 ) : Node(nullptr,in1,in2) {
|
||||
init_class_id(Class_Add);
|
||||
}
|
||||
@ -55,6 +99,7 @@ public:
|
||||
// and flatten expressions (so that 1+x+2 becomes x+3).
|
||||
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
|
||||
Node* IdealIL(PhaseGVN* phase, bool can_reshape, BasicType bt);
|
||||
Node* Ideal_collapse_variable_times_con(PhaseGVN* phase, BasicType bt);
|
||||
|
||||
// Compute a new Type for this node. Basically we just do the pre-check,
|
||||
// then call the virtual add() to set the type.
|
||||
|
||||
@ -1253,6 +1253,24 @@ JAVA_INTEGER_SHIFT_OP(>>, java_shift_right_unsigned, jlong, julong)
|
||||
|
||||
#undef JAVA_INTEGER_SHIFT_OP
|
||||
|
||||
// Some convenient bit shift operations that accepts a BasicType as the last
|
||||
// argument. These avoid potential mistakes with overloaded functions only
|
||||
// distinguished by lhs argument type.
|
||||
#define JAVA_INTEGER_SHIFT_BASIC_TYPE(FUNC) \
|
||||
inline jlong FUNC(jlong lhs, jint rhs, BasicType bt) { \
|
||||
if (bt == T_INT) { \
|
||||
return FUNC((jint) lhs, rhs); \
|
||||
} \
|
||||
assert(bt == T_LONG, "unsupported basic type"); \
|
||||
return FUNC(lhs, rhs); \
|
||||
}
|
||||
|
||||
JAVA_INTEGER_SHIFT_BASIC_TYPE(java_shift_left)
|
||||
JAVA_INTEGER_SHIFT_BASIC_TYPE(java_shift_right)
|
||||
JAVA_INTEGER_SHIFT_BASIC_TYPE(java_shift_right_unsigned)
|
||||
|
||||
#undef JAVA_INTERGER_SHIFT_BASIC_TYPE
|
||||
|
||||
//----------------------------------------------------------------------------------------------------
|
||||
// The goal of this code is to provide saturating operations for int/uint.
|
||||
// Checks overflow conditions and saturates the result to min_jint/max_jint.
|
||||
|
||||
@ -192,6 +192,7 @@ TEST(TestJavaArithmetic, shift_left_jint) {
|
||||
const volatile ShiftOpJintData* data = asl_jint_data;
|
||||
for (size_t i = 0; i < ARRAY_SIZE(asl_jint_data); ++i) {
|
||||
ASSERT_EQ(data[i].r, java_shift_left(data[i].x, data[i].shift));
|
||||
ASSERT_EQ(data[i].r, java_shift_left(data[i].x, data[i].shift, T_INT));
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,6 +200,7 @@ TEST(TestJavaArithmetic, shift_left_jlong) {
|
||||
const volatile ShiftOpJlongData* data = asl_jlong_data;
|
||||
for (size_t i = 0; i < ARRAY_SIZE(asl_jlong_data); ++i) {
|
||||
ASSERT_EQ(data[i].r, java_shift_left(data[i].x, data[i].shift));
|
||||
ASSERT_EQ(data[i].r, java_shift_left(data[i].x, data[i].shift, T_LONG));
|
||||
}
|
||||
}
|
||||
|
||||
@ -262,6 +264,7 @@ TEST(TestJavaArithmetic, shift_right_jint) {
|
||||
const volatile ShiftOpJintData* data = asr_jint_data;
|
||||
for (size_t i = 0; i < ARRAY_SIZE(asr_jint_data); ++i) {
|
||||
ASSERT_EQ(data[i].r, java_shift_right(data[i].x, data[i].shift));
|
||||
ASSERT_EQ(data[i].r, java_shift_right(data[i].x, data[i].shift, T_INT));
|
||||
}
|
||||
}
|
||||
|
||||
@ -269,6 +272,7 @@ TEST(TestJavaArithmetic, shift_right_jlong) {
|
||||
const volatile ShiftOpJlongData* data = asr_jlong_data;
|
||||
for (size_t i = 0; i < ARRAY_SIZE(asr_jlong_data); ++i) {
|
||||
ASSERT_EQ(data[i].r, java_shift_right(data[i].x, data[i].shift));
|
||||
ASSERT_EQ(data[i].r, java_shift_right(data[i].x, data[i].shift, T_LONG));
|
||||
}
|
||||
}
|
||||
|
||||
@ -334,6 +338,7 @@ TEST(TestJavaArithmetic, shift_right_unsigned_jint) {
|
||||
const volatile ShiftOpJintData* data = lsr_jint_data;
|
||||
for (size_t i = 0; i < ARRAY_SIZE(lsr_jint_data); ++i) {
|
||||
ASSERT_EQ(data[i].r, java_shift_right_unsigned(data[i].x, data[i].shift));
|
||||
ASSERT_EQ(data[i].r, java_shift_right_unsigned(data[i].x, data[i].shift, T_INT));
|
||||
}
|
||||
}
|
||||
|
||||
@ -341,5 +346,6 @@ TEST(TestJavaArithmetic, shift_right_unsigned_jlong) {
|
||||
const volatile ShiftOpJlongData* data = lsr_jlong_data;
|
||||
for (size_t i = 0; i < ARRAY_SIZE(lsr_jlong_data); ++i) {
|
||||
ASSERT_EQ(data[i].r, java_shift_right_unsigned(data[i].x, data[i].shift));
|
||||
ASSERT_EQ(data[i].r, java_shift_right_unsigned(data[i].x, data[i].shift, T_LONG));
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,430 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Red Hat and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*/
|
||||
|
||||
package compiler.c2.gvn;
|
||||
|
||||
import compiler.lib.generators.Generators;
|
||||
import compiler.lib.generators.RestrictableGenerator;
|
||||
import compiler.lib.ir_framework.Test;
|
||||
import compiler.lib.ir_framework.*;
|
||||
import jdk.test.lib.Asserts;
|
||||
import jdk.test.lib.Utils;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
/*
|
||||
* @test
|
||||
* @bug 8325495 8347555
|
||||
* @summary C2 should optimize addition of the same terms by collapsing them into one multiplication.
|
||||
* @library /test/lib /
|
||||
* @run driver compiler.c2.gvn.TestCollapsingSameTermAdditions
|
||||
*/
|
||||
public class TestCollapsingSameTermAdditions {
|
||||
private static final RestrictableGenerator<Integer> GEN_INT = Generators.G.ints();
|
||||
private static final RestrictableGenerator<Long> GEN_LONG = Generators.G.longs();
|
||||
|
||||
public static void main(String[] args) {
|
||||
TestFramework.run();
|
||||
}
|
||||
|
||||
@Run(test = {
|
||||
"addTo2",
|
||||
"addTo3",
|
||||
"addTo4",
|
||||
"shiftAndAddTo4",
|
||||
"mulAndAddTo4",
|
||||
"addTo5",
|
||||
"addTo6",
|
||||
"addTo7",
|
||||
"addTo8",
|
||||
"addTo16",
|
||||
"addAndShiftTo16",
|
||||
"addTo42",
|
||||
"mulAndAddTo42",
|
||||
"mulAndAddToMax",
|
||||
"mulAndAddToOverflow",
|
||||
"mulAndAddToZero",
|
||||
"mulAndAddToMinus1",
|
||||
"mulAndAddToMinus42"
|
||||
})
|
||||
private void runIntTests() {
|
||||
for (int a : new int[] { 0, 1, Integer.MIN_VALUE, Integer.MAX_VALUE, GEN_INT.next() }) {
|
||||
Asserts.assertEQ(a * 2, addTo2(a));
|
||||
Asserts.assertEQ(a * 3, addTo3(a));
|
||||
Asserts.assertEQ(a * 4, addTo4(a));
|
||||
Asserts.assertEQ(a * 4, shiftAndAddTo4(a));
|
||||
Asserts.assertEQ(a * 4, mulAndAddTo4(a));
|
||||
Asserts.assertEQ(a * 5, addTo5(a));
|
||||
Asserts.assertEQ(a * 6, addTo6(a));
|
||||
Asserts.assertEQ(a * 7, addTo7(a));
|
||||
Asserts.assertEQ(a * 8, addTo8(a));
|
||||
Asserts.assertEQ(a * 16, addTo16(a));
|
||||
Asserts.assertEQ(a * 16, addAndShiftTo16(a));
|
||||
Asserts.assertEQ(a * 42, addTo42(a));
|
||||
Asserts.assertEQ(a * 42, mulAndAddTo42(a));
|
||||
Asserts.assertEQ(a * Integer.MAX_VALUE, mulAndAddToMax(a));
|
||||
Asserts.assertEQ(a * Integer.MIN_VALUE, mulAndAddToOverflow(a));
|
||||
Asserts.assertEQ(0, mulAndAddToZero(a));
|
||||
Asserts.assertEQ(a * -1, mulAndAddToMinus1(a));
|
||||
Asserts.assertEQ(a * -42, mulAndAddToMinus42(a));
|
||||
}
|
||||
}
|
||||
|
||||
@Run(test = {
|
||||
"mulAndAddToIntOverflowL",
|
||||
"mulAndAddToMaxL",
|
||||
"mulAndAddToOverflowL"
|
||||
})
|
||||
private void runLongTests() {
|
||||
for (long a : new long[] { 0, 1, Long.MIN_VALUE, Long.MAX_VALUE, GEN_LONG.next() }) {
|
||||
Asserts.assertEQ(a * (Integer.MAX_VALUE + 1L), mulAndAddToIntOverflowL(a));
|
||||
Asserts.assertEQ(a * Long.MAX_VALUE, mulAndAddToMaxL(a));
|
||||
Asserts.assertEQ(a * Long.MIN_VALUE, mulAndAddToOverflowL(a));
|
||||
}
|
||||
}
|
||||
|
||||
@Run(test = {
|
||||
"bitShiftToOverflow",
|
||||
"bitShiftToOverflowL"
|
||||
})
|
||||
private void runBitShiftTests() {
|
||||
Asserts.assertEQ(95, bitShiftToOverflow());
|
||||
Asserts.assertEQ(191L, bitShiftToOverflowL());
|
||||
}
|
||||
|
||||
// ----- integer tests -----
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1" })
|
||||
@IR(failOn = IRNode.LSHIFT_I)
|
||||
private static int addTo2(int a) {
|
||||
return a + a; // Simple additions like a + a should be kept as-is
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo3(int a) {
|
||||
return a + a + a; // a*3 => (a<<1) + a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo4(int a) {
|
||||
return a + a + a + a; // a*4 => a<<2
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int shiftAndAddTo4(int a) {
|
||||
return (a << 1) + a + a; // a*2 + a + a => a*3 + a => a*4 => a<<2
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int mulAndAddTo4(int a) {
|
||||
return a * 3 + a; // a*4 => a<<2
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo5(int a) {
|
||||
return a + a + a + a + a; // a*5 => (a<<2) + a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "2" })
|
||||
private static int addTo6(int a) {
|
||||
return a + a + a + a + a + a; // a*6 => (a<<1) + (a<<2)
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
|
||||
private static int addTo7(int a) {
|
||||
return a + a + a + a + a + a + a; // a*7 => (a<<3) - a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo8(int a) {
|
||||
return a + a + a + a + a + a + a + a; // a*8 => a<<3
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo16(int a) {
|
||||
return a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a + a + a + a + a; // a*16 => a<<4
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int addAndShiftTo16(int a) {
|
||||
return (a + a) << 3; // a<<(3 + 1) => a<<4
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.MUL_I, "1" })
|
||||
private static int addTo42(int a) {
|
||||
return a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a; // a*42
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.MUL_I, "1" })
|
||||
private static int mulAndAddTo42(int a) {
|
||||
return a * 40 + a + a; // a*41 + a => a*42
|
||||
}
|
||||
|
||||
private static final int INT_MAX_MINUS_ONE = Integer.MAX_VALUE - 1;
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
|
||||
private static int mulAndAddToMax(int a) {
|
||||
return a * INT_MAX_MINUS_ONE + a; // a*MAX => a*(MIN-1) => a*MIN - a => (a<<31) - a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int mulAndAddToOverflow(int a) {
|
||||
return a * Integer.MAX_VALUE + a; // a*(MAX+1) => a*(MIN) => a<<31
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.CON_I, "1" })
|
||||
private static int mulAndAddToZero(int a) {
|
||||
return a * -1 + a; // 0
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
|
||||
private static int mulAndAddToMinus1(int a) {
|
||||
return a * -2 + a; // a*-1 => a - (a<<1)
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.MUL_I, "1" })
|
||||
private static int mulAndAddToMinus42(int a) {
|
||||
return a * -43 + a; // a*-42
|
||||
}
|
||||
|
||||
// --- long tests ---
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_L)
|
||||
@IR(counts = { IRNode.LSHIFT_L, "1" })
|
||||
private static long mulAndAddToIntOverflowL(long a) {
|
||||
return a * Integer.MAX_VALUE + a; // a*(INT_MAX+1)
|
||||
}
|
||||
|
||||
private static final long LONG_MAX_MINUS_ONE = Long.MAX_VALUE - 1;
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_L)
|
||||
@IR(counts = { IRNode.LSHIFT_L, "1", IRNode.SUB_L, "1" })
|
||||
private static long mulAndAddToMaxL(long a) {
|
||||
return a * LONG_MAX_MINUS_ONE + a; // a*MAX => a*(MIN-1) => a*MIN - a => (a<<63) - a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_L)
|
||||
@IR(counts = { IRNode.LSHIFT_L, "1" })
|
||||
private static long mulAndAddToOverflowL(long a) {
|
||||
return a * Long.MAX_VALUE + a; // a*(MAX+1) => a*(MIN) => a<<63
|
||||
}
|
||||
|
||||
// --- bit shift tests ---
|
||||
@Test
|
||||
@IR(failOn = {IRNode.ADD_I, IRNode.LSHIFT_I})
|
||||
private static int bitShiftToOverflow() {
|
||||
int i, x = 0;
|
||||
for (i = 0; i < 32; i++) {
|
||||
x = i;
|
||||
}
|
||||
|
||||
// x = 31 (phi), i = 32 (phi + 1)
|
||||
return i + (x << i) + i; // Expects 32 + 31 + 32 = 95
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = {IRNode.ADD_L, IRNode.LSHIFT_L})
|
||||
private static long bitShiftToOverflowL() {
|
||||
int i, x = 0;
|
||||
for (i = 0; i < 64; i++) {
|
||||
x = i;
|
||||
}
|
||||
|
||||
// x = 63 (phi), i = 64 (phi + 1)
|
||||
return i + (x << i) + i; // Expects 64 + 63 + 64 = 191
|
||||
}
|
||||
|
||||
// --- random tests ---
|
||||
private static final int CON1_I, CON2_I, CON3_I, CON4_I;
|
||||
private static final long CON1_L, CON2_L, CON3_L, CON4_L;
|
||||
|
||||
static {
|
||||
CON1_I = GEN_INT.next();
|
||||
CON2_I = GEN_INT.next();
|
||||
CON3_I = GEN_INT.next();
|
||||
CON4_I = GEN_INT.next();
|
||||
|
||||
CON1_L = GEN_LONG.next();
|
||||
CON2_L = GEN_LONG.next();
|
||||
CON3_L = GEN_LONG.next();
|
||||
CON4_L = GEN_LONG.next();
|
||||
}
|
||||
|
||||
@Run(test = {
|
||||
"randomPowerOfTwoAddition",
|
||||
"randomPowerOfTwoAdditionL"
|
||||
})
|
||||
private void runRandomPowerOfTwoAddition() {
|
||||
for (int a : new int[] { 0, 1, Integer.MIN_VALUE, Integer.MAX_VALUE, GEN_INT.next() }) {
|
||||
Asserts.assertEQ(a * (CON1_I + CON2_I + CON3_I + CON4_I), randomPowerOfTwoAddition(a));
|
||||
}
|
||||
|
||||
for (long a : new long[] { 0, 1, Long.MIN_VALUE, Long.MAX_VALUE, GEN_LONG.next() }) {
|
||||
Asserts.assertEQ(a * (CON1_L + CON2_L + CON3_L + CON4_L), randomPowerOfTwoAdditionL(a));
|
||||
}
|
||||
}
|
||||
|
||||
// We can't do IR verification but only check for correctness for a better confidence.
|
||||
@Test
|
||||
private static int randomPowerOfTwoAddition(int a) {
|
||||
return a * CON1_I + a * CON2_I + a * CON3_I + a * CON4_I;
|
||||
}
|
||||
|
||||
@Test
|
||||
private static long randomPowerOfTwoAdditionL(long a) {
|
||||
return a * CON1_L + a * CON2_L + a * CON3_L + a * CON4_L;
|
||||
}
|
||||
|
||||
// Patterns that originally could not be recognized because their right precedence makes matching difficult without
|
||||
// recursion, but some of which are made possible by swapping lhs and rhs.
|
||||
@Run(test = {
|
||||
"rightPrecedence",
|
||||
"rightPrecedenceL",
|
||||
"rightPrecedenceShift",
|
||||
"rightPrecedenceShiftL",
|
||||
})
|
||||
private void runLhsRhsSwaps() {
|
||||
for (int a : new int[] { 0, 1, Integer.MIN_VALUE, Integer.MAX_VALUE, GEN_INT.next() }) {
|
||||
Asserts.assertEQ(a * 3, rightPrecedence(a));
|
||||
Asserts.assertEQ(a * 4, rightPrecedenceShift(a));
|
||||
}
|
||||
|
||||
for (long a : new long[] { 0, 1, Long.MIN_VALUE, Long.MAX_VALUE, GEN_LONG.next() }) {
|
||||
Asserts.assertEQ(a * 3, rightPrecedenceL(a));
|
||||
Asserts.assertEQ(a * 4, rightPrecedenceShiftL(a));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" })
|
||||
private static int rightPrecedence(int a) {
|
||||
return a + (a + a);
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_L, "1", IRNode.LSHIFT_L, "1" })
|
||||
private static long rightPrecedenceL(long a) {
|
||||
return a + (a + a);
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int rightPrecedenceShift(int a) {
|
||||
return a + (a << 1) + a; // a + a*2 + a => a*2 + a + a => a*3 + a => a*4 => a<<2
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_L)
|
||||
@IR(counts = { IRNode.LSHIFT_L, "1" })
|
||||
private static long rightPrecedenceShiftL(long a) {
|
||||
return a + (a << 1) + a; // a + a*2 + a => a*2 + a + a => a*3 + a => a*4 => a<<2
|
||||
}
|
||||
|
||||
// JDK-8347555 only aims to cover cases minimally needed for patterns a + a + ... + a => n*a. However, some patterns
|
||||
// like CON * a + a => (CON + 1) * a are considered unintended side-effects due to the way pattern matching is
|
||||
// implemented.
|
||||
//
|
||||
// The following are patterns that could be, mathematically speaking, optimized, but are not implemented at this stage.
|
||||
// These tests are to be updated if they are addressed in the future.
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "2", IRNode.LSHIFT_I, "2" })
|
||||
@Arguments(values = { Argument.RANDOM_EACH })
|
||||
private static int complexShiftPattern(int a) {
|
||||
return a + (a << 1) + (a << 2); // This could've been: a + a*2 + a*4 => a*7
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "2" }) // b = a + a, c = b + b
|
||||
@Arguments(values = { Argument.RANDOM_EACH })
|
||||
private static int nestedAddPattern(int a) {
|
||||
return (a + a) + (a + a); // This could've been: 2*a + 2*a => 4*a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "3", IRNode.LSHIFT_I, "1" })
|
||||
@Arguments(values = { Argument.RANDOM_EACH })
|
||||
private static int complexPrecedence(int a) {
|
||||
return a + a + ((a + a) + a); // This could've been: 2*a + (2*a + a) => 2*a + 3*a => 5*a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_L, "2", IRNode.LSHIFT_L, "2" })
|
||||
@Arguments(values = { Argument.RANDOM_EACH })
|
||||
private static long complexShiftPatternL(long a) {
|
||||
return a + (a << 1) + (a << 2); // This could've been: a + a*2 + a*4 => a*7
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_L, "2" }) // b = a + a, c = b + b
|
||||
@Arguments(values = { Argument.RANDOM_EACH })
|
||||
private static long nestedAddPatternL(long a) {
|
||||
return (a + a) + (a + a); // This could've been: 2*a + 2*a => 4*a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_L, "3", IRNode.LSHIFT_L, "1" })
|
||||
@Arguments(values = { Argument.RANDOM_EACH })
|
||||
private static long complexPrecedenceL(long a) {
|
||||
return a + a + ((a + a) + a); // This could've been: 2*a + (2*a + a) => 2*a + 3*a => 5*a
|
||||
}
|
||||
}
|
||||
202
test/micro/org/openjdk/bench/vm/compiler/SerialAdditions.java
Normal file
202
test/micro/org/openjdk/bench/vm/compiler/SerialAdditions.java
Normal file
@ -0,0 +1,202 @@
|
||||
/*
|
||||
* Copyright (c) 2025, Red Hat and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*/
|
||||
package org.openjdk.bench.vm.compiler;
|
||||
|
||||
import org.openjdk.jmh.annotations.*;
|
||||
import org.openjdk.jmh.infra.Blackhole;
|
||||
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Tests speed of adding a series of additions of the same operand.
|
||||
*/
|
||||
@BenchmarkMode(Mode.AverageTime)
|
||||
@OutputTimeUnit(TimeUnit.NANOSECONDS)
|
||||
@State(Scope.Thread)
|
||||
@Warmup(iterations = 4, time = 2, timeUnit = TimeUnit.SECONDS)
|
||||
@Measurement(iterations = 4, time = 2, timeUnit = TimeUnit.SECONDS)
|
||||
@Fork(value = 3)
|
||||
public class SerialAdditions {
|
||||
private int a = 0xBADB0BA;
|
||||
private long b = 0x900dba51l;
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo02() {
|
||||
return a + a; // baseline, still a + a
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo04() {
|
||||
return a + a + a + a; // a*4 => a<<2
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo05() {
|
||||
return a + a + a + a + a; // a*5 => (a<<2) + a
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo06() {
|
||||
return a + a + a + a + a + a; // a*6 => (a<<1) + (a<<2)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo08() {
|
||||
return a + a + a + a + a + a + a + a; // a*8 => a<<3
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo16() {
|
||||
return a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a; // a*16 => a<<4
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo23() {
|
||||
return a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a; // a*23
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo32() {
|
||||
return a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a; // a*32 => a<<5
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo42() {
|
||||
return a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a; // a*42
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public int addIntsTo64() {
|
||||
return a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a + a + a + a + a + a + a //
|
||||
+ a + a + a + a; // 64 * a => a << 6
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public void addIntsMixed(Blackhole blackhole) {
|
||||
blackhole.consume(addIntsTo02());
|
||||
blackhole.consume(addIntsTo04());
|
||||
blackhole.consume(addIntsTo05());
|
||||
blackhole.consume(addIntsTo06());
|
||||
blackhole.consume(addIntsTo08());
|
||||
blackhole.consume(addIntsTo16());
|
||||
blackhole.consume(addIntsTo23());
|
||||
blackhole.consume(addIntsTo32());
|
||||
blackhole.consume(addIntsTo42());
|
||||
blackhole.consume(addIntsTo64());
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo02() {
|
||||
return b + b; // baseline, still a + a
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo04() {
|
||||
return b + b + b + b; // a*4 => a<<2
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo05() {
|
||||
return b + b + b + b + b; // a*5 => (a<<2) + a
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo06() {
|
||||
return b + b + b + b + b + b; // a*6 => (a<<1) + (a<<2)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo08() {
|
||||
return b + b + b + b + b + b + b + b; // a*8 => a<<3
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo16() {
|
||||
return b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b; // a*16 => a<<4
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo23() {
|
||||
return b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b; // a*23
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo32() {
|
||||
return b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b; // a*32 => a<<5
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo42() {
|
||||
return b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b; // a*42
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public long addLongsTo64() {
|
||||
return b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b + b + b + b + b + b + b //
|
||||
+ b + b + b + b; // 64 * a => a << 6
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public void addLongsMixed(Blackhole blackhole) {
|
||||
blackhole.consume(addLongsTo02());
|
||||
blackhole.consume(addLongsTo04());
|
||||
blackhole.consume(addLongsTo05());
|
||||
blackhole.consume(addLongsTo06());
|
||||
blackhole.consume(addLongsTo08());
|
||||
blackhole.consume(addLongsTo16());
|
||||
blackhole.consume(addLongsTo23());
|
||||
blackhole.consume(addLongsTo32());
|
||||
blackhole.consume(addLongsTo42());
|
||||
blackhole.consume(addLongsTo64());
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user