mirror of
https://github.com/openjdk/jdk.git
synced 2026-02-28 11:10:26 +00:00
8347554: [BACKOUT] C2: implement optimization for series of Add of unique value
Reviewed-by: thartmann
This commit is contained in:
parent
a289bcfe7e
commit
062f2dcfe5
@ -395,159 +395,9 @@ Node* AddNode::IdealIL(PhaseGVN* phase, bool can_reshape, BasicType bt) {
|
||||
}
|
||||
}
|
||||
|
||||
// Convert a + a + ... + a into a*n
|
||||
Node* serial_additions = convert_serial_additions(phase, bt);
|
||||
if (serial_additions != nullptr) {
|
||||
return serial_additions;
|
||||
}
|
||||
|
||||
return AddNode::Ideal(phase, can_reshape);
|
||||
}
|
||||
|
||||
// Try to convert a serial of additions into a single multiplication. Also convert `(a * CON) + a` to `(CON + 1) * a` as
|
||||
// a side effect. On success, a new MulNode is returned.
|
||||
Node* AddNode::convert_serial_additions(PhaseGVN* phase, BasicType bt) {
|
||||
// We need to make sure that the current AddNode is not part of a MulNode that has already been optimized to a
|
||||
// power-of-2 addition (e.g., 3 * a => (a << 2) + a). Without this check, GVN would keep trying to optimize the same
|
||||
// node and can't progress. For example, 3 * a => (a << 2) + a => 3 * a => (a << 2) + a => ...
|
||||
if (find_power_of_two_addition_pattern(this, bt, nullptr) != nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Node* in1 = in(1);
|
||||
Node* in2 = in(2);
|
||||
jlong multiplier;
|
||||
|
||||
// While multiplications can be potentially optimized to power-of-2 subtractions (e.g., a * 7 => (a << 3) - a),
|
||||
// (x - y) + y => x is already handled by the Identity() methods. So, we don't need to check for that pattern here.
|
||||
if (find_simple_addition_pattern(in1, bt, &multiplier) == in2
|
||||
|| find_simple_lshift_pattern(in1, bt, &multiplier) == in2
|
||||
|| find_simple_multiplication_pattern(in1, bt, &multiplier) == in2
|
||||
|| find_power_of_two_addition_pattern(in1, bt, &multiplier) == in2) {
|
||||
multiplier++; // +1 for the in2 term
|
||||
|
||||
Node* con = (bt == T_INT)
|
||||
? (Node*) phase->intcon((jint) multiplier) // intentional type narrowing to allow overflow at max_jint
|
||||
: (Node*) phase->longcon(multiplier);
|
||||
return MulNode::make(con, in2, bt);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Try to match `a + a`. On success, return `a` and set `2` as `multiplier`.
|
||||
// The method matches `n` for pattern: AddNode(a, a).
|
||||
Node* AddNode::find_simple_addition_pattern(Node* n, BasicType bt, jlong* multiplier) {
|
||||
if (n->Opcode() == Op_Add(bt) && n->in(1) == n->in(2)) {
|
||||
*multiplier = 2;
|
||||
return n->in(1);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Try to match `a << CON`. On success, return `a` and set `1 << CON` as `multiplier`.
|
||||
// Match `n` for pattern: LShiftNode(a, CON).
|
||||
// Note that the power-of-2 multiplication optimization could potentially convert a MulNode to this pattern.
|
||||
Node* AddNode::find_simple_lshift_pattern(Node* n, BasicType bt, jlong* multiplier) {
|
||||
// Note that power-of-2 multiplication optimization could potentially convert a MulNode to this pattern
|
||||
if (n->Opcode() == Op_LShift(bt) && n->in(2)->is_Con()) {
|
||||
Node* con = n->in(2);
|
||||
if (con->is_top()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
*multiplier = ((jlong) 1 << con->get_int());
|
||||
return n->in(1);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Try to match `CON * a`. On success, return `a` and set `CON` as `multiplier`.
|
||||
// Match `n` for patterns:
|
||||
// - MulNode(CON, a)
|
||||
// - MulNode(a, CON)
|
||||
Node* AddNode::find_simple_multiplication_pattern(Node* n, BasicType bt, jlong* multiplier) {
|
||||
// This optimization technically only produces MulNode(CON, a), but we might as match MulNode(a, CON), too.
|
||||
if (n->Opcode() == Op_Mul(bt) && (n->in(1)->is_Con() || n->in(2)->is_Con())) {
|
||||
Node* con = n->in(1);
|
||||
Node* base = n->in(2);
|
||||
|
||||
// swap ConNode to lhs for easier matching
|
||||
if (!con->is_Con()) {
|
||||
swap(con, base);
|
||||
}
|
||||
|
||||
if (con->is_top()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
*multiplier = con->get_integer_as_long(bt);
|
||||
return base;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Try to match `(a << CON1) + (a << CON2)`. On success, return `a` and set `(1 << CON1) + (1 << CON2)` as `multiplier`.
|
||||
// Match `n` for patterns:
|
||||
// - AddNode(LShiftNode(a, CON), LShiftNode(a, CON)/a)
|
||||
// - AddNode(LShiftNode(a, CON)/a, LShiftNode(a, CON))
|
||||
// given that lhs is different from rhs.
|
||||
// Note that one of the term of the addition could simply be `a` (i.e., a << 0). Calling this function with `multiplier`
|
||||
// being null is safe.
|
||||
Node* AddNode::find_power_of_two_addition_pattern(Node* n, BasicType bt, jlong* multiplier) {
|
||||
if (n->Opcode() == Op_Add(bt) && n->in(1) != n->in(2)) {
|
||||
Node* lhs = n->in(1);
|
||||
Node* rhs = n->in(2);
|
||||
|
||||
// swap LShiftNode to lhs for easier matching
|
||||
if (lhs->Opcode() != Op_LShift(bt)) {
|
||||
swap(lhs, rhs);
|
||||
}
|
||||
|
||||
// AddNode(LShiftNode(a, CON), *)?
|
||||
if (lhs->Opcode() != Op_LShift(bt) || !lhs->in(2)->is_Con()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
jlong lhs_multiplier = 0;
|
||||
if (multiplier != nullptr) {
|
||||
Node* con = lhs->in(2);
|
||||
if (con->is_top()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
lhs_multiplier = (jlong) 1 << con->get_int();
|
||||
}
|
||||
|
||||
// AddNode(LShiftNode(a, CON), a)?
|
||||
if (lhs->in(1) == rhs) {
|
||||
if (multiplier != nullptr) {
|
||||
*multiplier = lhs_multiplier + 1;
|
||||
}
|
||||
|
||||
return rhs;
|
||||
}
|
||||
|
||||
// AddNode(LShiftNode(a, CON), LShiftNode(a, CON2))?
|
||||
if (rhs->Opcode() == Op_LShift(bt) && lhs->in(1) == rhs->in(1) && rhs->in(2)->is_Con()) {
|
||||
if (multiplier != nullptr) {
|
||||
Node* con = rhs->in(2);
|
||||
if (con->is_top()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
*multiplier = lhs_multiplier + ((jlong) 1 << con->get_int());
|
||||
}
|
||||
|
||||
return lhs->in(1);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Node* AddINode::Ideal(PhaseGVN* phase, bool can_reshape) {
|
||||
Node* in1 = in(1);
|
||||
|
||||
@ -42,13 +42,6 @@ typedef const Pair<Node*, jint> ConstAddOperands;
|
||||
// by virtual functions.
|
||||
class AddNode : public Node {
|
||||
virtual uint hash() const;
|
||||
|
||||
Node* convert_serial_additions(PhaseGVN* phase, BasicType bt);
|
||||
static Node* find_simple_addition_pattern(Node* n, BasicType bt, jlong* multiplier);
|
||||
static Node* find_simple_lshift_pattern(Node* n, BasicType bt, jlong* multiplier);
|
||||
static Node* find_simple_multiplication_pattern(Node* n, BasicType bt, jlong* multiplier);
|
||||
static Node* find_power_of_two_addition_pattern(Node* n, BasicType bt, jlong* multiplier);
|
||||
|
||||
public:
|
||||
AddNode( Node *in1, Node *in2 ) : Node(nullptr,in1,in2) {
|
||||
init_class_id(Class_Add);
|
||||
|
||||
@ -1,257 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2024 Red Hat and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*/
|
||||
|
||||
package compiler.c2;
|
||||
|
||||
import compiler.lib.ir_framework.Test;
|
||||
import compiler.lib.ir_framework.*;
|
||||
import jdk.test.lib.Asserts;
|
||||
import jdk.test.lib.Utils;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
/*
|
||||
* @test
|
||||
* @bug 8325495
|
||||
* @summary C2 should optimize for series of Add of unique value. e.g., a + a + ... + a => a*n
|
||||
* @library /test/lib /
|
||||
* @run driver compiler.c2.TestSerialAdditions
|
||||
*/
|
||||
public class TestSerialAdditions {
|
||||
private static final Random RNG = Utils.getRandomInstance();
|
||||
|
||||
public static void main(String[] args) {
|
||||
TestFramework.run();
|
||||
}
|
||||
|
||||
@Run(test = {
|
||||
"addTo2",
|
||||
"addTo3",
|
||||
"addTo4",
|
||||
"shiftAndAddTo4",
|
||||
"mulAndAddTo4",
|
||||
"addTo5",
|
||||
"addTo6",
|
||||
"addTo7",
|
||||
"addTo8",
|
||||
"addTo16",
|
||||
"addAndShiftTo16",
|
||||
"addTo42",
|
||||
"mulAndAddTo42",
|
||||
"mulAndAddToMax",
|
||||
"mulAndAddToOverflow",
|
||||
"mulAndAddToZero",
|
||||
"mulAndAddToMinus1",
|
||||
"mulAndAddToMinus42"
|
||||
})
|
||||
private void runIntTests() {
|
||||
for (int a : new int[] { 0, 1, Integer.MIN_VALUE, Integer.MAX_VALUE, RNG.nextInt() }) {
|
||||
Asserts.assertEQ(a * 2, addTo2(a));
|
||||
Asserts.assertEQ(a * 3, addTo3(a));
|
||||
Asserts.assertEQ(a * 4, addTo4(a));
|
||||
Asserts.assertEQ(a * 4, shiftAndAddTo4(a));
|
||||
Asserts.assertEQ(a * 4, mulAndAddTo4(a));
|
||||
Asserts.assertEQ(a * 5, addTo5(a));
|
||||
Asserts.assertEQ(a * 6, addTo6(a));
|
||||
Asserts.assertEQ(a * 7, addTo7(a));
|
||||
Asserts.assertEQ(a * 8, addTo8(a));
|
||||
Asserts.assertEQ(a * 16, addTo16(a));
|
||||
Asserts.assertEQ(a * 16, addAndShiftTo16(a));
|
||||
Asserts.assertEQ(a * 42, addTo42(a));
|
||||
Asserts.assertEQ(a * 42, mulAndAddTo42(a));
|
||||
Asserts.assertEQ(a * Integer.MAX_VALUE, mulAndAddToMax(a));
|
||||
Asserts.assertEQ(a * Integer.MIN_VALUE, mulAndAddToOverflow(a));
|
||||
Asserts.assertEQ(0, mulAndAddToZero(a));
|
||||
Asserts.assertEQ(a * -1, mulAndAddToMinus1(a));
|
||||
Asserts.assertEQ(a * -42, mulAndAddToMinus42(a));
|
||||
}
|
||||
}
|
||||
|
||||
@Run(test = {
|
||||
"mulAndAddToIntOverflowL",
|
||||
"mulAndAddToMaxL",
|
||||
"mulAndAddToOverflowL"
|
||||
})
|
||||
private void runLongTests() {
|
||||
for (long a : new long[] { 0, 1, Long.MIN_VALUE, Long.MAX_VALUE, RNG.nextLong() }) {
|
||||
Asserts.assertEQ(a * (Integer.MAX_VALUE + 1L), mulAndAddToIntOverflowL(a));
|
||||
Asserts.assertEQ(a * Long.MAX_VALUE, mulAndAddToMaxL(a));
|
||||
Asserts.assertEQ(a * Long.MIN_VALUE, mulAndAddToOverflowL(a));
|
||||
}
|
||||
}
|
||||
|
||||
// ----- integer tests -----
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1" })
|
||||
@IR(failOn = IRNode.LSHIFT_I)
|
||||
private static int addTo2(int a) {
|
||||
return a + a; // Simple additions like a + a should be kept as-is
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo3(int a) {
|
||||
return a + a + a; // a*3 => (a<<1) + a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo4(int a) {
|
||||
return a + a + a + a; // a*4 => a<<2
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int shiftAndAddTo4(int a) {
|
||||
return (a << 1) + a + a; // a*2 + a + a => a*3 + a => a*4 => a<<2
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int mulAndAddTo4(int a) {
|
||||
return a * 3 + a; // a*4 => a<<2
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo5(int a) {
|
||||
return a + a + a + a + a; // a*5 => (a<<2) + a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "2" })
|
||||
private static int addTo6(int a) {
|
||||
return a + a + a + a + a + a; // a*6 => (a<<1) + (a<<2)
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
|
||||
private static int addTo7(int a) {
|
||||
return a + a + a + a + a + a + a; // a*7 => (a<<3) - a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo8(int a) {
|
||||
return a + a + a + a + a + a + a + a; // a*8 => a<<3
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int addTo16(int a) {
|
||||
return a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a + a + a + a + a; // a*16 => a<<4
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int addAndShiftTo16(int a) {
|
||||
return (a + a) << 3; // a<<(3 + 1) => a<<4
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.MUL_I, "1" })
|
||||
private static int addTo42(int a) {
|
||||
return a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a + a + a + a + a + a + a + a + a
|
||||
+ a + a; // a*42
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.MUL_I, "1" })
|
||||
private static int mulAndAddTo42(int a) {
|
||||
return a * 40 + a + a; // a*41 + a => a*42
|
||||
}
|
||||
|
||||
private static final int INT_MAX_MINUS_ONE = Integer.MAX_VALUE - 1;
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
|
||||
private static int mulAndAddToMax(int a) {
|
||||
return a * INT_MAX_MINUS_ONE + a; // a*MAX => a*(MIN-1) => a*MIN - a => (a<<31) - a
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1" })
|
||||
private static int mulAndAddToOverflow(int a) {
|
||||
return a * Integer.MAX_VALUE + a; // a*(MAX+1) => a*(MIN) => a<<31
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.CON_I, "1" })
|
||||
private static int mulAndAddToZero(int a) {
|
||||
return a * -1 + a; // 0
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" })
|
||||
private static int mulAndAddToMinus1(int a) {
|
||||
return a * -2 + a; // a*-1 => a - (a<<1)
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_I)
|
||||
@IR(counts = { IRNode.MUL_I, "1" })
|
||||
private static int mulAndAddToMinus42(int a) {
|
||||
return a * -43 + a; // a*-42
|
||||
}
|
||||
|
||||
// --- long tests ---
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_L)
|
||||
@IR(counts = { IRNode.LSHIFT_L, "1" })
|
||||
private static long mulAndAddToIntOverflowL(long a) {
|
||||
return a * Integer.MAX_VALUE + a; // a*(INT_MAX+1)
|
||||
}
|
||||
|
||||
private static final long LONG_MAX_MINUS_ONE = Long.MAX_VALUE - 1;
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_L)
|
||||
@IR(counts = { IRNode.LSHIFT_L, "1", IRNode.SUB_L, "1" })
|
||||
private static long mulAndAddToMaxL(long a) {
|
||||
return a * LONG_MAX_MINUS_ONE + a; // a*MAX => a*(MIN-1) => a*MIN - 1 => (a<<63) - 1
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(failOn = IRNode.ADD_L)
|
||||
@IR(counts = { IRNode.LSHIFT_L, "1" })
|
||||
private static long mulAndAddToOverflowL(long a) {
|
||||
return a * Long.MAX_VALUE + a; // a*(MAX+1) => a*(MIN) => a<<63
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user