From d16a9b2ec507251a44f034f1ccf8039f02023d52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galder=20Zamarre=C3=B1o?= Date: Thu, 15 Jan 2026 07:22:54 +0000 Subject: [PATCH] 8373134: C2: Min/Max users of Min/Max uses should be enqueued for GVN Reviewed-by: epeter, bmaillard, dlong --- src/hotspot/share/opto/addnode.cpp | 30 +-- src/hotspot/share/opto/addnode.hpp | 46 ++--- src/hotspot/share/opto/loopnode.cpp | 18 +- src/hotspot/share/opto/macro.cpp | 4 +- src/hotspot/share/opto/movenode.cpp | 4 +- src/hotspot/share/opto/node.hpp | 3 + src/hotspot/share/opto/phaseX.cpp | 9 + src/hotspot/share/opto/vectorization.cpp | 2 +- .../compiler/igvn/TestMinMaxIdentity.java | 186 ++++++++++++++++++ 9 files changed, 251 insertions(+), 51 deletions(-) create mode 100644 test/hotspot/jtreg/compiler/igvn/TestMinMaxIdentity.java diff --git a/src/hotspot/share/opto/addnode.cpp b/src/hotspot/share/opto/addnode.cpp index 40cd6337c17..e04da430ef0 100644 --- a/src/hotspot/share/opto/addnode.cpp +++ b/src/hotspot/share/opto/addnode.cpp @@ -1195,7 +1195,7 @@ const Type* XorLNode::Value(PhaseGVN* phase) const { return AddNode::Value(phase); } -Node* MaxNode::build_min_max_int(Node* a, Node* b, bool is_max) { +Node* MinMaxNode::build_min_max_int(Node* a, Node* b, bool is_max) { if (is_max) { return new MaxINode(a, b); } else { @@ -1203,7 +1203,7 @@ Node* MaxNode::build_min_max_int(Node* a, Node* b, bool is_max) { } } -Node* MaxNode::build_min_max_long(PhaseGVN* phase, Node* a, Node* b, bool is_max) { +Node* MinMaxNode::build_min_max_long(PhaseGVN* phase, Node* a, Node* b, bool is_max) { if (is_max) { return new MaxLNode(phase->C, a, b); } else { @@ -1211,7 +1211,7 @@ Node* MaxNode::build_min_max_long(PhaseGVN* phase, Node* a, Node* b, bool is_max } } -Node* MaxNode::build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, const Type* t, PhaseGVN& gvn) { +Node* MinMaxNode::build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, const Type* t, PhaseGVN& gvn) { bool is_int = gvn.type(a)->isa_int(); assert(is_int || gvn.type(a)->isa_long(), "int or long inputs"); assert(is_int == (gvn.type(b)->isa_int() != nullptr), "inconsistent inputs"); @@ -1243,7 +1243,7 @@ Node* MaxNode::build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, co return res; } -Node* MaxNode::build_min_max_diff_with_zero(Node* a, Node* b, bool is_max, const Type* t, PhaseGVN& gvn) { +Node* MinMaxNode::build_min_max_diff_with_zero(Node* a, Node* b, bool is_max, const Type* t, PhaseGVN& gvn) { bool is_int = gvn.type(a)->isa_int(); assert(is_int || gvn.type(a)->isa_long(), "int or long inputs"); assert(is_int == (gvn.type(b)->isa_int() != nullptr), "inconsistent inputs"); @@ -1290,7 +1290,7 @@ static bool can_overflow(const TypeLong* t, jlong c) { // Let = x_operands and = y_operands. // If x == y and neither add(x, x_off) nor add(y, y_off) overflow, return // add(x, op(x_off, y_off)). Otherwise, return nullptr. -Node* MaxNode::extract_add(PhaseGVN* phase, ConstAddOperands x_operands, ConstAddOperands y_operands) { +Node* MinMaxNode::extract_add(PhaseGVN* phase, ConstAddOperands x_operands, ConstAddOperands y_operands) { Node* x = x_operands.first; Node* y = y_operands.first; int opcode = Opcode(); @@ -1327,7 +1327,7 @@ static ConstAddOperands as_add_with_constant(Node* n) { return ConstAddOperands(x, c_type->is_int()->get_con()); } -Node* MaxNode::IdealI(PhaseGVN* phase, bool can_reshape) { +Node* MinMaxNode::IdealI(PhaseGVN* phase, bool can_reshape) { Node* n = AddNode::Ideal(phase, can_reshape); if (n != nullptr) { return n; @@ -1401,7 +1401,7 @@ Node* MaxINode::Identity(PhaseGVN* phase) { return in(2); } - return MaxNode::Identity(phase); + return MinMaxNode::Identity(phase); } //============================================================================= @@ -1434,7 +1434,7 @@ Node* MinINode::Identity(PhaseGVN* phase) { return in(1); } - return MaxNode::Identity(phase); + return MinMaxNode::Identity(phase); } //------------------------------add_ring--------------------------------------- @@ -1564,7 +1564,7 @@ Node* MaxLNode::Identity(PhaseGVN* phase) { return in(2); } - return MaxNode::Identity(phase); + return MinMaxNode::Identity(phase); } Node* MaxLNode::Ideal(PhaseGVN* phase, bool can_reshape) { @@ -1596,7 +1596,7 @@ Node* MinLNode::Identity(PhaseGVN* phase) { return in(1); } - return MaxNode::Identity(phase); + return MinMaxNode::Identity(phase); } Node* MinLNode::Ideal(PhaseGVN* phase, bool can_reshape) { @@ -1610,7 +1610,7 @@ Node* MinLNode::Ideal(PhaseGVN* phase, bool can_reshape) { return nullptr; } -int MaxNode::opposite_opcode() const { +int MinMaxNode::opposite_opcode() const { if (Opcode() == max_opcode()) { return min_opcode(); } else { @@ -1621,7 +1621,7 @@ int MaxNode::opposite_opcode() const { // Given a redundant structure such as Max/Min(A, Max/Min(B, C)) where A == B or A == C, return the useful part of the structure. // 'operation' is the node expected to be the inner 'Max/Min(B, C)', and 'operand' is the node expected to be the 'A' operand of the outer node. -Node* MaxNode::find_identity_operation(Node* operation, Node* operand) { +Node* MinMaxNode::find_identity_operation(Node* operation, Node* operand) { if (operation->Opcode() == Opcode() || operation->Opcode() == opposite_opcode()) { Node* n1 = operation->in(1); Node* n2 = operation->in(2); @@ -1645,17 +1645,17 @@ Node* MaxNode::find_identity_operation(Node* operation, Node* operand) { return nullptr; } -Node* MaxNode::Identity(PhaseGVN* phase) { +Node* MinMaxNode::Identity(PhaseGVN* phase) { if (in(1) == in(2)) { return in(1); } - Node* identity_1 = MaxNode::find_identity_operation(in(2), in(1)); + Node* identity_1 = MinMaxNode::find_identity_operation(in(2), in(1)); if (identity_1 != nullptr) { return identity_1; } - Node* identity_2 = MaxNode::find_identity_operation(in(1), in(2)); + Node* identity_2 = MinMaxNode::find_identity_operation(in(1), in(2)); if (identity_2 != nullptr) { return identity_2; } diff --git a/src/hotspot/share/opto/addnode.hpp b/src/hotspot/share/opto/addnode.hpp index 1bbdae92e48..4151ab5d065 100644 --- a/src/hotspot/share/opto/addnode.hpp +++ b/src/hotspot/share/opto/addnode.hpp @@ -324,14 +324,16 @@ public: //------------------------------MaxNode---------------------------------------- // Max (or min) of 2 values. Included with the ADD nodes because it inherits // all the behavior of addition on a ring. -class MaxNode : public AddNode { +class MinMaxNode : public AddNode { private: static Node* build_min_max(Node* a, Node* b, bool is_max, bool is_unsigned, const Type* t, PhaseGVN& gvn); static Node* build_min_max_diff_with_zero(Node* a, Node* b, bool is_max, const Type* t, PhaseGVN& gvn); Node* extract_add(PhaseGVN* phase, ConstAddOperands x_operands, ConstAddOperands y_operands); public: - MaxNode( Node *in1, Node *in2 ) : AddNode(in1,in2) {} + MinMaxNode(Node* in1, Node* in2) : AddNode(in1, in2) { + init_class_id(Class_MinMax); + } virtual int Opcode() const = 0; virtual int max_opcode() const = 0; virtual int min_opcode() const = 0; @@ -373,9 +375,9 @@ public: //------------------------------MaxINode--------------------------------------- // Maximum of 2 integers. Included with the ADD nodes because it inherits // all the behavior of addition on a ring. -class MaxINode : public MaxNode { +class MaxINode : public MinMaxNode { public: - MaxINode( Node *in1, Node *in2 ) : MaxNode(in1,in2) {} + MaxINode(Node* in1, Node* in2) : MinMaxNode(in1, in2) {} virtual int Opcode() const; virtual const Type *add_ring( const Type *, const Type * ) const; virtual const Type *add_id() const { return TypeInt::make(min_jint); } @@ -390,9 +392,9 @@ public: //------------------------------MinINode--------------------------------------- // MINimum of 2 integers. Included with the ADD nodes because it inherits // all the behavior of addition on a ring. -class MinINode : public MaxNode { +class MinINode : public MinMaxNode { public: - MinINode( Node *in1, Node *in2 ) : MaxNode(in1,in2) {} + MinINode(Node* in1, Node* in2) : MinMaxNode(in1, in2) {} virtual int Opcode() const; virtual const Type *add_ring( const Type *, const Type * ) const; virtual const Type *add_id() const { return TypeInt::make(max_jint); } @@ -406,9 +408,9 @@ public: //------------------------------MaxLNode--------------------------------------- // MAXimum of 2 longs. -class MaxLNode : public MaxNode { +class MaxLNode : public MinMaxNode { public: - MaxLNode(Compile* C, Node* in1, Node* in2) : MaxNode(in1, in2) { + MaxLNode(Compile* C, Node* in1, Node* in2) : MinMaxNode(in1, in2) { init_flags(Flag_is_macro); C->add_macro_node(this); } @@ -425,9 +427,9 @@ public: //------------------------------MinLNode--------------------------------------- // MINimum of 2 longs. -class MinLNode : public MaxNode { +class MinLNode : public MinMaxNode { public: - MinLNode(Compile* C, Node* in1, Node* in2) : MaxNode(in1, in2) { + MinLNode(Compile* C, Node* in1, Node* in2) : MinMaxNode(in1, in2) { init_flags(Flag_is_macro); C->add_macro_node(this); } @@ -444,9 +446,9 @@ public: //------------------------------MaxFNode--------------------------------------- // Maximum of 2 floats. -class MaxFNode : public MaxNode { +class MaxFNode : public MinMaxNode { public: - MaxFNode(Node *in1, Node *in2) : MaxNode(in1, in2) {} + MaxFNode(Node* in1, Node* in2) : MinMaxNode(in1, in2) {} virtual int Opcode() const; virtual const Type *add_ring(const Type*, const Type*) const; virtual const Type *add_id() const { return TypeF::NEG_INF; } @@ -458,9 +460,9 @@ public: //------------------------------MinFNode--------------------------------------- // Minimum of 2 floats. -class MinFNode : public MaxNode { +class MinFNode : public MinMaxNode { public: - MinFNode(Node *in1, Node *in2) : MaxNode(in1, in2) {} + MinFNode(Node* in1, Node* in2) : MinMaxNode(in1, in2) {} virtual int Opcode() const; virtual const Type *add_ring(const Type*, const Type*) const; virtual const Type *add_id() const { return TypeF::POS_INF; } @@ -472,9 +474,9 @@ public: //------------------------------MaxHFNode-------------------------------------- // Maximum of 2 half floats. -class MaxHFNode : public MaxNode { +class MaxHFNode : public MinMaxNode { public: - MaxHFNode(Node* in1, Node* in2) : MaxNode(in1, in2) {} + MaxHFNode(Node* in1, Node* in2) : MinMaxNode(in1, in2) {} virtual int Opcode() const; virtual const Type* add_ring(const Type*, const Type*) const; virtual const Type* add_id() const { return TypeH::NEG_INF; } @@ -486,9 +488,9 @@ public: //------------------------------MinHFNode--------------------------------------- // Minimum of 2 half floats. -class MinHFNode : public MaxNode { +class MinHFNode : public MinMaxNode { public: - MinHFNode(Node* in1, Node* in2) : MaxNode(in1, in2) {} + MinHFNode(Node* in1, Node* in2) : MinMaxNode(in1, in2) {} virtual int Opcode() const; virtual const Type* add_ring(const Type*, const Type*) const; virtual const Type* add_id() const { return TypeH::POS_INF; } @@ -500,9 +502,9 @@ public: //------------------------------MaxDNode--------------------------------------- // Maximum of 2 doubles. -class MaxDNode : public MaxNode { +class MaxDNode : public MinMaxNode { public: - MaxDNode(Node *in1, Node *in2) : MaxNode(in1, in2) {} + MaxDNode(Node* in1, Node* in2) : MinMaxNode(in1, in2) {} virtual int Opcode() const; virtual const Type *add_ring(const Type*, const Type*) const; virtual const Type *add_id() const { return TypeD::NEG_INF; } @@ -514,9 +516,9 @@ public: //------------------------------MinDNode--------------------------------------- // Minimum of 2 doubles. -class MinDNode : public MaxNode { +class MinDNode : public MinMaxNode { public: - MinDNode(Node *in1, Node *in2) : MaxNode(in1, in2) {} + MinDNode(Node* in1, Node* in2) : MinMaxNode(in1, in2) {} virtual int Opcode() const; virtual const Type *add_ring(const Type*, const Type*) const; virtual const Type *add_id() const { return TypeD::POS_INF; } diff --git a/src/hotspot/share/opto/loopnode.cpp b/src/hotspot/share/opto/loopnode.cpp index dacc1a1a734..8dc34af9c19 100644 --- a/src/hotspot/share/opto/loopnode.cpp +++ b/src/hotspot/share/opto/loopnode.cpp @@ -979,9 +979,9 @@ bool PhaseIdealLoop::create_loop_nest(IdealLoopTree* loop, Node_List &old_new) { Node* inner_iters_max = nullptr; if (stride_con > 0) { - inner_iters_max = MaxNode::max_diff_with_zero(limit, outer_phi, TypeInteger::bottom(bt), _igvn); + inner_iters_max = MinMaxNode::max_diff_with_zero(limit, outer_phi, TypeInteger::bottom(bt), _igvn); } else { - inner_iters_max = MaxNode::max_diff_with_zero(outer_phi, limit, TypeInteger::bottom(bt), _igvn); + inner_iters_max = MinMaxNode::max_diff_with_zero(outer_phi, limit, TypeInteger::bottom(bt), _igvn); } Node* inner_iters_limit = _igvn.integercon(iters_limit, bt); @@ -989,7 +989,7 @@ bool PhaseIdealLoop::create_loop_nest(IdealLoopTree* loop, Node_List &old_new) { // Long.MIN_VALUE to Long.MAX_VALUE for instance). Use an unsigned // min. const TypeInteger* inner_iters_actual_range = TypeInteger::make(0, iters_limit, Type::WidenMin, bt); - Node* inner_iters_actual = MaxNode::unsigned_min(inner_iters_max, inner_iters_limit, inner_iters_actual_range, _igvn); + Node* inner_iters_actual = MinMaxNode::unsigned_min(inner_iters_max, inner_iters_limit, inner_iters_actual_range, _igvn); Node* inner_iters_actual_int; if (bt == T_LONG) { @@ -1618,7 +1618,7 @@ void PhaseIdealLoop::transform_long_range_checks(int stride_con, const Node_List Node* max_jint_plus_one_long = longcon((jlong)max_jint + 1); Node* max_range = new AddLNode(max_jint_plus_one_long, L); register_new_node(max_range, entry_control); - R = MaxNode::unsigned_min(R, max_range, TypeLong::POS, _igvn); + R = MinMaxNode::unsigned_min(R, max_range, TypeLong::POS, _igvn); set_subtree_ctrl(R, true); } @@ -1717,9 +1717,9 @@ void PhaseIdealLoop::transform_long_range_checks(int stride_con, const Node_List } Node* PhaseIdealLoop::clamp(Node* R, Node* L, Node* H) { - Node* min = MaxNode::signed_min(R, H, TypeLong::LONG, _igvn); + Node* min = MinMaxNode::signed_min(R, H, TypeLong::LONG, _igvn); set_subtree_ctrl(min, true); - Node* max = MaxNode::signed_max(L, min, TypeLong::LONG, _igvn); + Node* max = MinMaxNode::signed_max(L, min, TypeLong::LONG, _igvn); set_subtree_ctrl(max, true); return max; } @@ -3485,14 +3485,14 @@ void OuterStripMinedLoopNode::adjust_strip_mined_loop(PhaseIterGVN* igvn) { // the loop body to be run for LoopStripMiningIter. Node* max = nullptr; if (stride > 0) { - max = MaxNode::max_diff_with_zero(limit, iv_phi, TypeInt::INT, *igvn); + max = MinMaxNode::max_diff_with_zero(limit, iv_phi, TypeInt::INT, *igvn); } else { - max = MaxNode::max_diff_with_zero(iv_phi, limit, TypeInt::INT, *igvn); + max = MinMaxNode::max_diff_with_zero(iv_phi, limit, TypeInt::INT, *igvn); } // sub is positive and can be larger than the max signed int // value. Use an unsigned min. Node* const_iters = igvn->intcon(scaled_iters); - Node* min = MaxNode::unsigned_min(max, const_iters, TypeInt::make(0, scaled_iters, Type::WidenMin), *igvn); + Node* min = MinMaxNode::unsigned_min(max, const_iters, TypeInt::make(0, scaled_iters, Type::WidenMin), *igvn); // min is the number of iterations for the next inner loop execution: // unsigned_min(max(limit - iv_phi, 0), scaled_iters) if stride > 0 // unsigned_min(max(iv_phi - limit, 0), scaled_iters) if stride < 0 diff --git a/src/hotspot/share/opto/macro.cpp b/src/hotspot/share/opto/macro.cpp index 80818a4ddc7..4df03714376 100644 --- a/src/hotspot/share/opto/macro.cpp +++ b/src/hotspot/share/opto/macro.cpp @@ -2577,11 +2577,11 @@ void PhaseMacroExpand::eliminate_opaque_looplimit_macro_nodes() { // a CMoveL construct now. At least until here, the type could be computed // precisely. CMoveL is not so smart, but we can give it at least the best // type we know abouot n now. - Node* repl = MaxNode::signed_max(n->in(1), n->in(2), _igvn.type(n), _igvn); + Node* repl = MinMaxNode::signed_max(n->in(1), n->in(2), _igvn.type(n), _igvn); _igvn.replace_node(n, repl); success = true; } else if (n->Opcode() == Op_MinL) { - Node* repl = MaxNode::signed_min(n->in(1), n->in(2), _igvn.type(n), _igvn); + Node* repl = MinMaxNode::signed_min(n->in(1), n->in(2), _igvn.type(n), _igvn); _igvn.replace_node(n, repl); success = true; } diff --git a/src/hotspot/share/opto/movenode.cpp b/src/hotspot/share/opto/movenode.cpp index 66db1df339b..6b6becb434f 100644 --- a/src/hotspot/share/opto/movenode.cpp +++ b/src/hotspot/share/opto/movenode.cpp @@ -271,9 +271,9 @@ Node* CMoveNode::Ideal_minmax(PhaseGVN* phase, CMoveNode* cmove) { // Create the Min/Max node based on the type and kind if (cmp_op == Op_CmpL) { - return MaxNode::build_min_max_long(phase, cmp_l, cmp_r, is_max); + return MinMaxNode::build_min_max_long(phase, cmp_l, cmp_r, is_max); } else { - return MaxNode::build_min_max_int(cmp_l, cmp_r, is_max); + return MinMaxNode::build_min_max_int(cmp_l, cmp_r, is_max); } } diff --git a/src/hotspot/share/opto/node.hpp b/src/hotspot/share/opto/node.hpp index 0adb2072100..f1d9785a746 100644 --- a/src/hotspot/share/opto/node.hpp +++ b/src/hotspot/share/opto/node.hpp @@ -130,6 +130,7 @@ class MemBarNode; class MemBarStoreStoreNode; class MemNode; class MergeMemNode; +class MinMaxNode; class MoveNode; class MulNode; class MultiNode; @@ -809,6 +810,7 @@ public: DEFINE_CLASS_ID(AddP, Node, 9) DEFINE_CLASS_ID(BoxLock, Node, 10) DEFINE_CLASS_ID(Add, Node, 11) + DEFINE_CLASS_ID(MinMax, Add, 0) DEFINE_CLASS_ID(Mul, Node, 12) DEFINE_CLASS_ID(ClearArray, Node, 14) DEFINE_CLASS_ID(Halt, Node, 15) @@ -986,6 +988,7 @@ public: DEFINE_CLASS_QUERY(MemBar) DEFINE_CLASS_QUERY(MemBarStoreStore) DEFINE_CLASS_QUERY(MergeMem) + DEFINE_CLASS_QUERY(MinMax) DEFINE_CLASS_QUERY(Move) DEFINE_CLASS_QUERY(Mul) DEFINE_CLASS_QUERY(Multi) diff --git a/src/hotspot/share/opto/phaseX.cpp b/src/hotspot/share/opto/phaseX.cpp index c4bdc5e8903..52badca8050 100644 --- a/src/hotspot/share/opto/phaseX.cpp +++ b/src/hotspot/share/opto/phaseX.cpp @@ -2633,6 +2633,15 @@ void PhaseIterGVN::add_users_of_use_to_worklist(Node* n, Node* use, Unique_Node_ } } } + // Check for Max/Min(A, Max/Min(B, C)) where A == B or A == C + if (use->is_MinMax()) { + for (DUIterator_Fast i2max, i2 = use->fast_outs(i2max); i2 < i2max; i2++) { + Node* u = use->fast_out(i2); + if (u->Opcode() == use->Opcode()) { + worklist.push(u); + } + } + } auto enqueue_init_mem_projs = [&](ProjNode* proj) { add_users_to_worklist0(proj, worklist); }; diff --git a/src/hotspot/share/opto/vectorization.cpp b/src/hotspot/share/opto/vectorization.cpp index 1755b0453eb..8e0ca927a16 100644 --- a/src/hotspot/share/opto/vectorization.cpp +++ b/src/hotspot/share/opto/vectorization.cpp @@ -1122,7 +1122,7 @@ Node* make_last(Node* initL, jint stride, Node* limitL, PhaseIdealLoop* phase) { Node* last = new AddLNode(initL, k_mul_stride); // Make sure that the last does not lie "before" init. - Node* last_clamped = MaxNode::build_min_max_long(&igvn, initL, last, stride > 0); + Node* last_clamped = MinMaxNode::build_min_max_long(&igvn, initL, last, stride > 0); phase->register_new_node_with_ctrl_of(diffL, initL); phase->register_new_node_with_ctrl_of(diffL_m1, initL); diff --git a/test/hotspot/jtreg/compiler/igvn/TestMinMaxIdentity.java b/test/hotspot/jtreg/compiler/igvn/TestMinMaxIdentity.java new file mode 100644 index 00000000000..d358359ff14 --- /dev/null +++ b/test/hotspot/jtreg/compiler/igvn/TestMinMaxIdentity.java @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2025 IBM Corporation. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8373134 + * @summary Verify that min/max add identity optimizations get applied correctly + * @modules java.base/jdk.internal.misc + * @modules jdk.incubator.vector + * @library /test/lib / + * @run driver ${test.main.class} + */ + +package compiler.igvn; + +import compiler.lib.compile_framework.CompileFramework; +import compiler.lib.template_framework.Template; +import compiler.lib.template_framework.TemplateToken; +import compiler.lib.template_framework.library.CodeGenerationDataNameType; +import compiler.lib.template_framework.library.PrimitiveType; +import compiler.lib.template_framework.library.TestFrameworkClass; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.stream.Stream; + +import static compiler.lib.template_framework.Template.let; +import static compiler.lib.template_framework.Template.scope; + +public class TestMinMaxIdentity { + public static void main(String[] args) { + // Create a new CompileFramework instance. + CompileFramework comp = new CompileFramework(); + + // Add a java source file. + comp.addJavaSourceCode("compiler.igvn.templated.MinMaxIdentity", generate(comp)); + + // Compile the source file. + comp.compile("--add-modules=jdk.incubator.vector"); + + comp.invoke("compiler.igvn.templated.MinMaxIdentity", "main", new Object[] {new String[] { + "--add-modules=jdk.incubator.vector", + "--add-opens", "jdk.incubator.vector/jdk.incubator.vector=ALL-UNNAMED" + }}); + } + + private static String generate(CompileFramework comp) { + // Create a list to collect all tests. + List testTemplateTokens = new ArrayList<>(); + + Stream.of(MinMaxOp.values()) + .flatMap(MinMaxOp::generate) + .forEach(testTemplateTokens::add); + + Stream.of(Fp16MinMaxOp.values()) + .flatMap(Fp16MinMaxOp::generate) + .forEach(testTemplateTokens::add); + + // Create the test class, which runs all testTemplateTokens. + return TestFrameworkClass.render( + // package and class name. + "compiler.igvn.templated", "MinMaxIdentity", + // List of imports. + Set.of("jdk.incubator.vector.Float16"), + // classpath, so the Test VM has access to the compiled class files. + comp.getEscapedClassPathOfCompiledClasses(), + // The list of tests. + testTemplateTokens); + } + + enum MinMaxOp { + MIN_D("min", CodeGenerationDataNameType.doubles()), + MAX_D("max", CodeGenerationDataNameType.doubles()), + MIN_F("min", CodeGenerationDataNameType.floats()), + MAX_F("max", CodeGenerationDataNameType.floats()), + MIN_I("min", CodeGenerationDataNameType.ints()), + MAX_I("max", CodeGenerationDataNameType.ints()), + MIN_L("min", CodeGenerationDataNameType.longs()), + MAX_L("max", CodeGenerationDataNameType.longs()); + + final String functionName; + final PrimitiveType type; + + MinMaxOp(String functionName, PrimitiveType type) { + this.functionName = functionName; + this.type = type; + } + + Stream generate() { + return Stream.of(template("a", "b"), template("b", "a")). + map(Template.ZeroArgs::asToken); + } + + private Template.ZeroArgs template(String arg1, String arg2) { + return Template.make(() -> scope( + let("boxedTypeName", type.boxedTypeName()), + let("op", name()), + let("type", type.name()), + let("functionName", functionName), + let("arg1", arg1), + let("arg2", arg2), + """ + @Test + @IR(counts = {IRNode.#op, "= 1"}, + phase = CompilePhase.BEFORE_MACRO_EXPANSION) + @Arguments(values = {Argument.NUMBER_42, Argument.NUMBER_42}) + public #type $test(#type #arg1, #type #arg2) { + int i; + for (i = -10; i < 1; i++) { + } + #type c = a * i; + return #boxedTypeName.#functionName(a, #boxedTypeName.#functionName(b, c)); + } + """ + )); + } + } + + enum Fp16MinMaxOp { + MAX_HF("max"), + MIN_HF("min"); + + final String functionName; + + Fp16MinMaxOp(String functionName) { + this.functionName = functionName; + } + + Stream generate() { + return Stream.of(template("a", "b"), template("b", "a")). + map(Template.ZeroArgs::asToken); + } + + private Template.ZeroArgs template(String arg1, String arg2) { + return Template.make(() -> scope( + let("op", name()), + let("functionName", functionName), + let("arg1", arg1), + let("arg2", arg2), + """ + @Setup + private static Object[] $setup() { + return new Object[] {Float16.valueOf(42), Float16.valueOf(42)}; + } + + @Test + @IR(counts = {IRNode.#op, "= 1"}, + phase = CompilePhase.BEFORE_MACRO_EXPANSION, + applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"}) + @IR(counts = {IRNode.#op, "= 1"}, + phase = CompilePhase.BEFORE_MACRO_EXPANSION, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}) + @Arguments(setup = "$setup") + public Float16 $test(Float16 #arg1, Float16 #arg2) { + int i; + for (i = -10; i < 1; i++) { + } + Float16 c = Float16.multiply(a, Float16.valueOf(i)); + return Float16.#functionName(a, Float16.#functionName(b, c)); + } + """ + )); + } + } +}