8341781: Improve Min/Max node identities

Reviewed-by: chagedorn
This commit is contained in:
Jasmine Karthikeyan 2024-12-10 22:32:19 +00:00
parent 4c39e9faa0
commit 29d648c642
5 changed files with 295 additions and 8 deletions

View File

@ -1413,6 +1413,20 @@ Node* MaxINode::Ideal(PhaseGVN* phase, bool can_reshape) {
return IdealI(phase, can_reshape);
}
Node* MaxINode::Identity(PhaseGVN* phase) {
const TypeInt* t1 = phase->type(in(1))->is_int();
const TypeInt* t2 = phase->type(in(2))->is_int();
// Can we determine the maximum statically?
if (t1->_lo >= t2->_hi) {
return in(1);
} else if (t2->_lo >= t1->_hi) {
return in(2);
}
return MaxNode::Identity(phase);
}
//=============================================================================
//------------------------------add_ring---------------------------------------
// Supplied function returns the sum of the inputs.
@ -1432,6 +1446,20 @@ Node* MinINode::Ideal(PhaseGVN* phase, bool can_reshape) {
return IdealI(phase, can_reshape);
}
Node* MinINode::Identity(PhaseGVN* phase) {
const TypeInt* t1 = phase->type(in(1))->is_int();
const TypeInt* t2 = phase->type(in(2))->is_int();
// Can we determine the minimum statically?
if (t1->_lo >= t2->_hi) {
return in(2);
} else if (t2->_lo >= t1->_hi) {
return in(1);
}
return MaxNode::Identity(phase);
}
//------------------------------add_ring---------------------------------------
// Supplied function returns the sum of the inputs.
const Type *MinINode::add_ring( const Type *t0, const Type *t1 ) const {
@ -1574,11 +1602,56 @@ Node* MinLNode::Ideal(PhaseGVN* phase, bool can_reshape) {
return nullptr;
}
int MaxNode::opposite_opcode() const {
if (Opcode() == max_opcode()) {
return min_opcode();
} else {
assert(Opcode() == min_opcode(), "Caller should be either %s or %s, but is %s", NodeClassNames[max_opcode()], NodeClassNames[min_opcode()], NodeClassNames[Opcode()]);
return max_opcode();
}
}
// Given a redundant structure such as Max/Min(A, Max/Min(B, C)) where A == B or A == C, return the useful part of the structure.
// 'operation' is the node expected to be the inner 'Max/Min(B, C)', and 'operand' is the node expected to be the 'A' operand of the outer node.
Node* MaxNode::find_identity_operation(Node* operation, Node* operand) {
if (operation->Opcode() == Opcode() || operation->Opcode() == opposite_opcode()) {
Node* n1 = operation->in(1);
Node* n2 = operation->in(2);
// Given Op(A, Op(B, C)), see if either A == B or A == C is true.
if (n1 == operand || n2 == operand) {
// If the operations are the same return the inner operation, as Max(A, Max(A, B)) == Max(A, B).
if (operation->Opcode() == Opcode()) {
return operation;
}
// If the operations are different return the operand 'A', as Max(A, Min(A, B)) == A if the value isn't floating point.
// With floating point values, the identity doesn't hold if B == NaN.
const Type* type = bottom_type();
if (type->isa_int() || type->isa_long()) {
return operand;
}
}
}
return nullptr;
}
Node* MaxNode::Identity(PhaseGVN* phase) {
if (in(1) == in(2)) {
return in(1);
}
Node* identity_1 = MaxNode::find_identity_operation(in(2), in(1));
if (identity_1 != nullptr) {
return identity_1;
}
Node* identity_2 = MaxNode::find_identity_operation(in(1), in(2));
if (identity_2 != nullptr) {
return identity_2;
}
return AddNode::Identity(phase);
}

View File

@ -262,8 +262,7 @@ public:
//------------------------------MaxNode----------------------------------------
// Max (or min) of 2 values. Included with the ADD nodes because it inherits
// all the behavior of addition on a ring. Only new thing is that we allow
// 2 equal inputs to be equal.
// all the behavior of addition on a ring.
class MaxNode : public AddNode {
private:
static Node* build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, const Type* t, PhaseGVN& gvn);
@ -277,6 +276,8 @@ public:
virtual int min_opcode() const = 0;
Node* IdealI(PhaseGVN* phase, bool can_reshape);
virtual Node* Identity(PhaseGVN* phase);
Node* find_identity_operation(Node* operation, Node* operand);
int opposite_opcode() const;
static Node* unsigned_max(Node* a, Node* b, const Type* t, PhaseGVN& gvn) {
return build_min_max(a, b, true, true, t, gvn);
@ -321,6 +322,7 @@ public:
virtual uint ideal_reg() const { return Op_RegI; }
int max_opcode() const { return Op_MaxI; }
int min_opcode() const { return Op_MinI; }
virtual Node* Identity(PhaseGVN* phase);
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
};
@ -337,7 +339,8 @@ public:
virtual uint ideal_reg() const { return Op_RegI; }
int max_opcode() const { return Op_MaxI; }
int min_opcode() const { return Op_MinI; }
virtual Node *Ideal(PhaseGVN *phase, bool can_reshape);
virtual Node* Identity(PhaseGVN* phase);
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
};
//------------------------------MaxLNode---------------------------------------

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, Arm Limited. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
@ -28,7 +28,7 @@ import compiler.lib.ir_framework.*;
/*
* @test
* @bug 8290248 8312547
* @bug 8290248 8312547 8341781
* @summary Test that Ideal transformations of MaxINode and MinINode are
* being performed as expected.
* @library /test/lib /
@ -46,10 +46,12 @@ public class MaxMinINodeIdealizationTests {
"testMax2LNoLeftAdd",
"testMax3",
"testMax4",
"testMax5",
"testMin1",
"testMin2",
"testMin3",
"testMin4"})
"testMin4",
"testMin5"})
public void runPositiveTests() {
int a = RunInfo.getRandom().nextInt();
int min = Integer.MIN_VALUE;
@ -76,11 +78,13 @@ public class MaxMinINodeIdealizationTests {
Asserts.assertEQ(Math.max(a >> 1, ((a >> 1) + 11)) , testMax2LNoLeftAdd(a));
Asserts.assertEQ(Math.max(a, a) , testMax3(a));
Asserts.assertEQ(0 , testMax4(a));
Asserts.assertEQ(8 , testMax5(a));
Asserts.assertEQ(Math.min(((a >> 1) + 100), Math.min(((a >> 1) + 150), 200)), testMin1(a));
Asserts.assertEQ(Math.min(((a >> 1) + 10), ((a >> 1) + 11)) , testMin2(a));
Asserts.assertEQ(Math.min(a, a) , testMin3(a));
Asserts.assertEQ(0 , testMin4(a));
Asserts.assertEQ(a & 7 , testMin5(a));
}
// The transformations in test*1 and test*2 can happen only if the compiler has enough information
@ -219,6 +223,18 @@ public class MaxMinINodeIdealizationTests {
return Math.min(i, 0) > 0 ? 1 : 0;
}
@Test
@IR(failOn = {IRNode.MAX_I})
public int testMax5(int i) {
return Math.max(i & 7, 8);
}
@Test
@IR(failOn = {IRNode.MIN_I})
public int testMin5(int i) {
return Math.min(i & 7, 8);
}
@Run(test = {"testTwoLevelsDifferentXY",
"testTwoLevelsNoLeftConstant",
"testTwoLevelsNoRightConstant",

View File

@ -0,0 +1,196 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package compiler.c2.irTests;
import jdk.test.lib.Asserts;
import compiler.lib.ir_framework.*;
import java.util.Random;
import jdk.test.lib.Utils;
/*
* @test
* @bug 8341781
* @summary Test identities of MinNodes and MaxNodes.
* @key randomness
* @library /test/lib /
* @run driver compiler.c2.irTests.TestMinMaxIdentities
*/
public class TestMinMaxIdentities {
private static final Random RANDOM = Utils.getRandomInstance();
public static void main(String[] args) {
TestFramework.run();
}
@Run(test = { "intMinMin", "intMinMax", "intMaxMin", "intMaxMax",
"longMinMin", "longMinMax", "longMaxMin", "longMaxMax",
"floatMinMin", "floatMaxMax", "doubleMinMin", "doubleMaxMax",
"floatMinMax", "floatMaxMin", "doubleMinMax", "doubleMaxMin" })
public void runMethod() {
assertResult(10, 20, 10L, 20L, 10.f, 20.f, 10.0, 20.0);
assertResult(20, 10, 20L, 10L, 20.f, 10.f, 20.0, 10.0);
assertResult(RANDOM.nextInt(), RANDOM.nextInt(), RANDOM.nextLong(), RANDOM.nextLong(), RANDOM.nextFloat(), RANDOM.nextFloat(), RANDOM.nextDouble(), RANDOM.nextDouble());
assertResult(RANDOM.nextInt(), RANDOM.nextInt(), RANDOM.nextLong(), RANDOM.nextLong(), RANDOM.nextFloat(), RANDOM.nextFloat(), RANDOM.nextDouble(), RANDOM.nextDouble());
assertResult(Integer.MAX_VALUE, Integer.MIN_VALUE, Long.MAX_VALUE, Long.MIN_VALUE, Float.POSITIVE_INFINITY, Float.NaN, Double.POSITIVE_INFINITY, Double.NaN);
assertResult(Integer.MIN_VALUE, Integer.MAX_VALUE, Long.MIN_VALUE, Long.MAX_VALUE, Float.NaN, Float.POSITIVE_INFINITY, Double.NaN, Double.POSITIVE_INFINITY);
}
@DontCompile
public void assertResult(int iA, int iB, long lA, long lB, float fA, float fB, double dA, double dB) {
Asserts.assertEQ(Math.min(iA, Math.min(iA, iB)), intMinMin(iA, iB));
Asserts.assertEQ(Math.min(iA, Math.max(iA, iB)), intMinMax(iA, iB));
Asserts.assertEQ(Math.max(iA, Math.min(iA, iB)), intMaxMin(iA, iB));
Asserts.assertEQ(Math.max(iA, Math.max(iA, iB)), intMaxMax(iA, iB));
Asserts.assertEQ(Math.min(lA, Math.min(lA, lB)), longMinMin(lA, lB));
Asserts.assertEQ(Math.min(lA, Math.max(lA, lB)), longMinMax(lA, lB));
Asserts.assertEQ(Math.max(lA, Math.min(lA, lB)), longMaxMin(lA, lB));
Asserts.assertEQ(Math.max(lA, Math.max(lA, lB)), longMaxMax(lA, lB));
Asserts.assertEQ(Math.min(fA, Math.min(fA, fB)), floatMinMin(fA, fB));
Asserts.assertEQ(Math.max(fA, Math.max(fA, fB)), floatMaxMax(fA, fB));
Asserts.assertEQ(Math.min(dA, Math.min(dA, dB)), doubleMinMin(dA, dB));
Asserts.assertEQ(Math.max(dA, Math.max(dA, dB)), doubleMaxMax(dA, dB));
// Due to NaN, these identities cannot be simplified.
Asserts.assertEQ(Math.min(fA, Math.max(fA, fB)), floatMinMax(fA, fB));
Asserts.assertEQ(Math.max(fA, Math.min(fA, fB)), floatMaxMin(fA, fB));
Asserts.assertEQ(Math.min(dA, Math.max(dA, dB)), doubleMinMax(dA, dB));
Asserts.assertEQ(Math.max(dA, Math.min(dA, dB)), doubleMaxMin(dA, dB));
}
// Integers
@Test
@IR(counts = { IRNode.MIN_I, "1" })
public int intMinMin(int a, int b) {
return Math.min(a, Math.min(a, b));
}
@Test
@IR(failOn = { IRNode.MIN_I, IRNode.MAX_I })
public int intMinMax(int a, int b) {
return Math.min(a, Math.max(a, b));
}
@Test
@IR(failOn = { IRNode.MIN_I, IRNode.MAX_I })
public int intMaxMin(int a, int b) {
return Math.max(a, Math.min(a, b));
}
@Test
@IR(counts = { IRNode.MAX_I, "1" })
public int intMaxMax(int a, int b) {
return Math.max(a, Math.max(a, b));
}
// Longs
// As Math.min/max(LL) is not intrinsified, it first needs to be transformed into CMoveL and then MinL/MaxL before
// the identity can be matched. However, the outer min/max is not transformed into CMove because of the CMove cost model.
// As JDK-8307513 adds intrinsics for the methods, the tests will be updated then.
@Test
@IR(applyIfPlatform = { "riscv64", "false" }, phase = { CompilePhase.BEFORE_MACRO_EXPANSION }, counts = { IRNode.MIN_L, "1" })
public long longMinMin(long a, long b) {
return Math.min(a, Math.min(a, b));
}
@Test
@IR(applyIfPlatform = { "riscv64", "false" }, phase = { CompilePhase.BEFORE_MACRO_EXPANSION }, counts = { IRNode.MIN_L, "1" })
public long longMinMax(long a, long b) {
return Math.min(a, Math.max(a, b));
}
@Test
@IR(applyIfPlatform = { "riscv64", "false" }, phase = { CompilePhase.BEFORE_MACRO_EXPANSION }, counts = { IRNode.MAX_L, "1" })
public long longMaxMin(long a, long b) {
return Math.max(a, Math.min(a, b));
}
@Test
@IR(applyIfPlatform = { "riscv64", "false" }, phase = { CompilePhase.BEFORE_MACRO_EXPANSION }, counts = { IRNode.MAX_L, "1" })
public long longMaxMax(long a, long b) {
return Math.max(a, Math.max(a, b));
}
// Floats
@Test
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_F, "1" })
public float floatMinMin(float a, float b) {
return Math.min(a, Math.min(a, b));
}
@Test
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MAX_F, "1" })
public float floatMaxMax(float a, float b) {
return Math.max(a, Math.max(a, b));
}
// Doubles
@Test
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_D, "1" })
public double doubleMinMin(double a, double b) {
return Math.min(a, Math.min(a, b));
}
@Test
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MAX_D, "1" })
public double doubleMaxMax(double a, double b) {
return Math.max(a, Math.max(a, b));
}
// Float and double identities that cannot be simplified due to NaN
@Test
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_F, "1", IRNode.MAX_F, "1" })
public float floatMinMax(float a, float b) {
return Math.min(a, Math.max(a, b));
}
@Test
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_F, "1", IRNode.MAX_F, "1" })
public float floatMaxMin(float a, float b) {
return Math.max(a, Math.min(a, b));
}
@Test
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_D, "1", IRNode.MAX_D, "1" })
public double doubleMinMax(double a, double b) {
return Math.min(a, Math.max(a, b));
}
@Test
@IR(applyIfCPUFeatureOr = {"avx", "true", "asimd", "true", "rvv", "true"}, counts = { IRNode.MIN_D, "1", IRNode.MAX_D, "1" })
public double doubleMaxMin(double a, double b) {
return Math.max(a, Math.min(a, b));
}
}

View File

@ -216,9 +216,8 @@ public class BasicShortOpTest extends VectorizationTestRunner {
@IR(failOn = {IRNode.STORE_VECTOR})
public short[] vectorMin() {
short[] res = new short[SIZE];
int val = 65536;
for (int i = 0; i < SIZE; i++) {
res[i] = (short) Math.min(a[i], val);
res[i] = (short) Math.min(a[i], b[i]);
}
return res;
}