diff --git a/src/hotspot/share/opto/addnode.cpp b/src/hotspot/share/opto/addnode.cpp index b44aa53f298..d9b62d2890d 100644 --- a/src/hotspot/share/opto/addnode.cpp +++ b/src/hotspot/share/opto/addnode.cpp @@ -395,159 +395,9 @@ Node* AddNode::IdealIL(PhaseGVN* phase, bool can_reshape, BasicType bt) { } } - // Convert a + a + ... + a into a*n - Node* serial_additions = convert_serial_additions(phase, bt); - if (serial_additions != nullptr) { - return serial_additions; - } - return AddNode::Ideal(phase, can_reshape); } -// Try to convert a serial of additions into a single multiplication. Also convert `(a * CON) + a` to `(CON + 1) * a` as -// a side effect. On success, a new MulNode is returned. -Node* AddNode::convert_serial_additions(PhaseGVN* phase, BasicType bt) { - // We need to make sure that the current AddNode is not part of a MulNode that has already been optimized to a - // power-of-2 addition (e.g., 3 * a => (a << 2) + a). Without this check, GVN would keep trying to optimize the same - // node and can't progress. For example, 3 * a => (a << 2) + a => 3 * a => (a << 2) + a => ... - if (find_power_of_two_addition_pattern(this, bt, nullptr) != nullptr) { - return nullptr; - } - - Node* in1 = in(1); - Node* in2 = in(2); - jlong multiplier; - - // While multiplications can be potentially optimized to power-of-2 subtractions (e.g., a * 7 => (a << 3) - a), - // (x - y) + y => x is already handled by the Identity() methods. So, we don't need to check for that pattern here. - if (find_simple_addition_pattern(in1, bt, &multiplier) == in2 - || find_simple_lshift_pattern(in1, bt, &multiplier) == in2 - || find_simple_multiplication_pattern(in1, bt, &multiplier) == in2 - || find_power_of_two_addition_pattern(in1, bt, &multiplier) == in2) { - multiplier++; // +1 for the in2 term - - Node* con = (bt == T_INT) - ? (Node*) phase->intcon((jint) multiplier) // intentional type narrowing to allow overflow at max_jint - : (Node*) phase->longcon(multiplier); - return MulNode::make(con, in2, bt); - } - - return nullptr; -} - -// Try to match `a + a`. On success, return `a` and set `2` as `multiplier`. -// The method matches `n` for pattern: AddNode(a, a). -Node* AddNode::find_simple_addition_pattern(Node* n, BasicType bt, jlong* multiplier) { - if (n->Opcode() == Op_Add(bt) && n->in(1) == n->in(2)) { - *multiplier = 2; - return n->in(1); - } - - return nullptr; -} - -// Try to match `a << CON`. On success, return `a` and set `1 << CON` as `multiplier`. -// Match `n` for pattern: LShiftNode(a, CON). -// Note that the power-of-2 multiplication optimization could potentially convert a MulNode to this pattern. -Node* AddNode::find_simple_lshift_pattern(Node* n, BasicType bt, jlong* multiplier) { - // Note that power-of-2 multiplication optimization could potentially convert a MulNode to this pattern - if (n->Opcode() == Op_LShift(bt) && n->in(2)->is_Con()) { - Node* con = n->in(2); - if (con->is_top()) { - return nullptr; - } - - *multiplier = ((jlong) 1 << con->get_int()); - return n->in(1); - } - - return nullptr; -} - -// Try to match `CON * a`. On success, return `a` and set `CON` as `multiplier`. -// Match `n` for patterns: -// - MulNode(CON, a) -// - MulNode(a, CON) -Node* AddNode::find_simple_multiplication_pattern(Node* n, BasicType bt, jlong* multiplier) { - // This optimization technically only produces MulNode(CON, a), but we might as match MulNode(a, CON), too. - if (n->Opcode() == Op_Mul(bt) && (n->in(1)->is_Con() || n->in(2)->is_Con())) { - Node* con = n->in(1); - Node* base = n->in(2); - - // swap ConNode to lhs for easier matching - if (!con->is_Con()) { - swap(con, base); - } - - if (con->is_top()) { - return nullptr; - } - - *multiplier = con->get_integer_as_long(bt); - return base; - } - - return nullptr; -} - -// Try to match `(a << CON1) + (a << CON2)`. On success, return `a` and set `(1 << CON1) + (1 << CON2)` as `multiplier`. -// Match `n` for patterns: -// - AddNode(LShiftNode(a, CON), LShiftNode(a, CON)/a) -// - AddNode(LShiftNode(a, CON)/a, LShiftNode(a, CON)) -// given that lhs is different from rhs. -// Note that one of the term of the addition could simply be `a` (i.e., a << 0). Calling this function with `multiplier` -// being null is safe. -Node* AddNode::find_power_of_two_addition_pattern(Node* n, BasicType bt, jlong* multiplier) { - if (n->Opcode() == Op_Add(bt) && n->in(1) != n->in(2)) { - Node* lhs = n->in(1); - Node* rhs = n->in(2); - - // swap LShiftNode to lhs for easier matching - if (lhs->Opcode() != Op_LShift(bt)) { - swap(lhs, rhs); - } - - // AddNode(LShiftNode(a, CON), *)? - if (lhs->Opcode() != Op_LShift(bt) || !lhs->in(2)->is_Con()) { - return nullptr; - } - - jlong lhs_multiplier = 0; - if (multiplier != nullptr) { - Node* con = lhs->in(2); - if (con->is_top()) { - return nullptr; - } - - lhs_multiplier = (jlong) 1 << con->get_int(); - } - - // AddNode(LShiftNode(a, CON), a)? - if (lhs->in(1) == rhs) { - if (multiplier != nullptr) { - *multiplier = lhs_multiplier + 1; - } - - return rhs; - } - - // AddNode(LShiftNode(a, CON), LShiftNode(a, CON2))? - if (rhs->Opcode() == Op_LShift(bt) && lhs->in(1) == rhs->in(1) && rhs->in(2)->is_Con()) { - if (multiplier != nullptr) { - Node* con = rhs->in(2); - if (con->is_top()) { - return nullptr; - } - - *multiplier = lhs_multiplier + ((jlong) 1 << con->get_int()); - } - - return lhs->in(1); - } - return nullptr; - } - return nullptr; -} Node* AddINode::Ideal(PhaseGVN* phase, bool can_reshape) { Node* in1 = in(1); diff --git a/src/hotspot/share/opto/addnode.hpp b/src/hotspot/share/opto/addnode.hpp index 0a2c42b7796..c409fb8cea8 100644 --- a/src/hotspot/share/opto/addnode.hpp +++ b/src/hotspot/share/opto/addnode.hpp @@ -42,13 +42,6 @@ typedef const Pair ConstAddOperands; // by virtual functions. class AddNode : public Node { virtual uint hash() const; - - Node* convert_serial_additions(PhaseGVN* phase, BasicType bt); - static Node* find_simple_addition_pattern(Node* n, BasicType bt, jlong* multiplier); - static Node* find_simple_lshift_pattern(Node* n, BasicType bt, jlong* multiplier); - static Node* find_simple_multiplication_pattern(Node* n, BasicType bt, jlong* multiplier); - static Node* find_power_of_two_addition_pattern(Node* n, BasicType bt, jlong* multiplier); - public: AddNode( Node *in1, Node *in2 ) : Node(nullptr,in1,in2) { init_class_id(Class_Add); diff --git a/test/hotspot/jtreg/compiler/c2/TestSerialAdditions.java b/test/hotspot/jtreg/compiler/c2/TestSerialAdditions.java deleted file mode 100644 index c52f17dd975..00000000000 --- a/test/hotspot/jtreg/compiler/c2/TestSerialAdditions.java +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Copyright (c) 2024 Red Hat and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ - -package compiler.c2; - -import compiler.lib.ir_framework.Test; -import compiler.lib.ir_framework.*; -import jdk.test.lib.Asserts; -import jdk.test.lib.Utils; - -import java.util.Random; - -/* - * @test - * @bug 8325495 - * @summary C2 should optimize for series of Add of unique value. e.g., a + a + ... + a => a*n - * @library /test/lib / - * @run driver compiler.c2.TestSerialAdditions - */ -public class TestSerialAdditions { - private static final Random RNG = Utils.getRandomInstance(); - - public static void main(String[] args) { - TestFramework.run(); - } - - @Run(test = { - "addTo2", - "addTo3", - "addTo4", - "shiftAndAddTo4", - "mulAndAddTo4", - "addTo5", - "addTo6", - "addTo7", - "addTo8", - "addTo16", - "addAndShiftTo16", - "addTo42", - "mulAndAddTo42", - "mulAndAddToMax", - "mulAndAddToOverflow", - "mulAndAddToZero", - "mulAndAddToMinus1", - "mulAndAddToMinus42" - }) - private void runIntTests() { - for (int a : new int[] { 0, 1, Integer.MIN_VALUE, Integer.MAX_VALUE, RNG.nextInt() }) { - Asserts.assertEQ(a * 2, addTo2(a)); - Asserts.assertEQ(a * 3, addTo3(a)); - Asserts.assertEQ(a * 4, addTo4(a)); - Asserts.assertEQ(a * 4, shiftAndAddTo4(a)); - Asserts.assertEQ(a * 4, mulAndAddTo4(a)); - Asserts.assertEQ(a * 5, addTo5(a)); - Asserts.assertEQ(a * 6, addTo6(a)); - Asserts.assertEQ(a * 7, addTo7(a)); - Asserts.assertEQ(a * 8, addTo8(a)); - Asserts.assertEQ(a * 16, addTo16(a)); - Asserts.assertEQ(a * 16, addAndShiftTo16(a)); - Asserts.assertEQ(a * 42, addTo42(a)); - Asserts.assertEQ(a * 42, mulAndAddTo42(a)); - Asserts.assertEQ(a * Integer.MAX_VALUE, mulAndAddToMax(a)); - Asserts.assertEQ(a * Integer.MIN_VALUE, mulAndAddToOverflow(a)); - Asserts.assertEQ(0, mulAndAddToZero(a)); - Asserts.assertEQ(a * -1, mulAndAddToMinus1(a)); - Asserts.assertEQ(a * -42, mulAndAddToMinus42(a)); - } - } - - @Run(test = { - "mulAndAddToIntOverflowL", - "mulAndAddToMaxL", - "mulAndAddToOverflowL" - }) - private void runLongTests() { - for (long a : new long[] { 0, 1, Long.MIN_VALUE, Long.MAX_VALUE, RNG.nextLong() }) { - Asserts.assertEQ(a * (Integer.MAX_VALUE + 1L), mulAndAddToIntOverflowL(a)); - Asserts.assertEQ(a * Long.MAX_VALUE, mulAndAddToMaxL(a)); - Asserts.assertEQ(a * Long.MIN_VALUE, mulAndAddToOverflowL(a)); - } - } - - // ----- integer tests ----- - @Test - @IR(counts = { IRNode.ADD_I, "1" }) - @IR(failOn = IRNode.LSHIFT_I) - private static int addTo2(int a) { - return a + a; // Simple additions like a + a should be kept as-is - } - - @Test - @IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" }) - private static int addTo3(int a) { - return a + a + a; // a*3 => (a<<1) + a - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1" }) - private static int addTo4(int a) { - return a + a + a + a; // a*4 => a<<2 - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1" }) - private static int shiftAndAddTo4(int a) { - return (a << 1) + a + a; // a*2 + a + a => a*3 + a => a*4 => a<<2 - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1" }) - private static int mulAndAddTo4(int a) { - return a * 3 + a; // a*4 => a<<2 - } - - @Test - @IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "1" }) - private static int addTo5(int a) { - return a + a + a + a + a; // a*5 => (a<<2) + a - } - - @Test - @IR(counts = { IRNode.ADD_I, "1", IRNode.LSHIFT_I, "2" }) - private static int addTo6(int a) { - return a + a + a + a + a + a; // a*6 => (a<<1) + (a<<2) - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" }) - private static int addTo7(int a) { - return a + a + a + a + a + a + a; // a*7 => (a<<3) - a - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1" }) - private static int addTo8(int a) { - return a + a + a + a + a + a + a + a; // a*8 => a<<3 - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1" }) - private static int addTo16(int a) { - return a + a + a + a + a + a + a + a + a + a - + a + a + a + a + a + a; // a*16 => a<<4 - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1" }) - private static int addAndShiftTo16(int a) { - return (a + a) << 3; // a<<(3 + 1) => a<<4 - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.MUL_I, "1" }) - private static int addTo42(int a) { - return a + a + a + a + a + a + a + a + a + a - + a + a + a + a + a + a + a + a + a + a - + a + a + a + a + a + a + a + a + a + a - + a + a + a + a + a + a + a + a + a + a - + a + a; // a*42 - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.MUL_I, "1" }) - private static int mulAndAddTo42(int a) { - return a * 40 + a + a; // a*41 + a => a*42 - } - - private static final int INT_MAX_MINUS_ONE = Integer.MAX_VALUE - 1; - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" }) - private static int mulAndAddToMax(int a) { - return a * INT_MAX_MINUS_ONE + a; // a*MAX => a*(MIN-1) => a*MIN - a => (a<<31) - a - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1" }) - private static int mulAndAddToOverflow(int a) { - return a * Integer.MAX_VALUE + a; // a*(MAX+1) => a*(MIN) => a<<31 - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.CON_I, "1" }) - private static int mulAndAddToZero(int a) { - return a * -1 + a; // 0 - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.LSHIFT_I, "1", IRNode.SUB_I, "1" }) - private static int mulAndAddToMinus1(int a) { - return a * -2 + a; // a*-1 => a - (a<<1) - } - - @Test - @IR(failOn = IRNode.ADD_I) - @IR(counts = { IRNode.MUL_I, "1" }) - private static int mulAndAddToMinus42(int a) { - return a * -43 + a; // a*-42 - } - - // --- long tests --- - @Test - @IR(failOn = IRNode.ADD_L) - @IR(counts = { IRNode.LSHIFT_L, "1" }) - private static long mulAndAddToIntOverflowL(long a) { - return a * Integer.MAX_VALUE + a; // a*(INT_MAX+1) - } - - private static final long LONG_MAX_MINUS_ONE = Long.MAX_VALUE - 1; - - @Test - @IR(failOn = IRNode.ADD_L) - @IR(counts = { IRNode.LSHIFT_L, "1", IRNode.SUB_L, "1" }) - private static long mulAndAddToMaxL(long a) { - return a * LONG_MAX_MINUS_ONE + a; // a*MAX => a*(MIN-1) => a*MIN - 1 => (a<<63) - 1 - } - - @Test - @IR(failOn = IRNode.ADD_L) - @IR(counts = { IRNode.LSHIFT_L, "1" }) - private static long mulAndAddToOverflowL(long a) { - return a * Long.MAX_VALUE + a; // a*(MAX+1) => a*(MIN) => a<<63 - } -}