mirror of
https://github.com/openjdk/jdk.git
synced 2026-04-23 05:10:57 +00:00
8354242: VectorAPI: combine vector not operation with compare
Reviewed-by: epeter, jbhateja, xgong
This commit is contained in:
parent
c2c44a061a
commit
45cc515f45
@ -1216,6 +1216,9 @@ bool Node::has_special_unique_user() const {
|
||||
} else if ((is_IfFalse() || is_IfTrue()) && n->is_If()) {
|
||||
// See IfNode::fold_compares
|
||||
return true;
|
||||
} else if (n->Opcode() == Op_XorV || n->Opcode() == Op_XorVMask) {
|
||||
// Condition for XorVMask(VectorMaskCmp(x,y,cond), MaskAll(true)) ==> VectorMaskCmp(x,y,ncond)
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -328,7 +328,9 @@ struct BoolTest {
|
||||
// a simple char array where each element is the ASCII version of a 'mask'
|
||||
// enum from above.
|
||||
mask commute( ) const { return mask("032147658"[_test]-'0'); }
|
||||
mask negate( ) const { return mask(_test^4); }
|
||||
mask negate( ) const { return negate_mask(_test); }
|
||||
// Return the negative mask for the given mask, for both signed and unsigned comparison.
|
||||
static mask negate_mask(mask btm) { return mask(btm ^ 4); }
|
||||
bool is_canonical( ) const { return (_test == BoolTest::ne || _test == BoolTest::lt || _test == BoolTest::le || _test == BoolTest::overflow); }
|
||||
bool is_less( ) const { return _test == BoolTest::lt || _test == BoolTest::le; }
|
||||
bool is_greater( ) const { return _test == BoolTest::gt || _test == BoolTest::ge; }
|
||||
|
||||
@ -2268,6 +2268,99 @@ Node* OrVNode::Identity(PhaseGVN* phase) {
|
||||
return redundant_logical_identity(this);
|
||||
}
|
||||
|
||||
// Returns whether (XorV (VectorMaskCmp) -1) can be optimized by negating the
|
||||
// comparison operation.
|
||||
bool VectorMaskCmpNode::predicate_can_be_negated() {
|
||||
switch (_predicate) {
|
||||
case BoolTest::eq:
|
||||
case BoolTest::ne:
|
||||
// eq and ne also apply to floating-point special values like NaN and infinities.
|
||||
return true;
|
||||
case BoolTest::le:
|
||||
case BoolTest::ge:
|
||||
case BoolTest::lt:
|
||||
case BoolTest::gt:
|
||||
case BoolTest::ule:
|
||||
case BoolTest::uge:
|
||||
case BoolTest::ult:
|
||||
case BoolTest::ugt: {
|
||||
BasicType bt = vect_type()->element_basic_type();
|
||||
// For float and double, we don't know if either comparison operand is a
|
||||
// NaN, NaN {le|ge|lt|gt} anything is false, resulting in inconsistent
|
||||
// results before and after negation.
|
||||
return is_integral_type(bt);
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// This function transforms the following patterns:
|
||||
//
|
||||
// For integer types:
|
||||
// (XorV (VectorMaskCmp src1 src2 cond) (Replicate -1))
|
||||
// => (VectorMaskCmp src1 src2 ncond)
|
||||
// (XorVMask (VectorMaskCmp src1 src2 cond) (MaskAll m1))
|
||||
// => (VectorMaskCmp src1 src2 ncond)
|
||||
// (XorV (VectorMaskCast (VectorMaskCmp src1 src2 cond)) (Replicate -1))
|
||||
// => (VectorMaskCast (VectorMaskCmp src1 src2 ncond))
|
||||
// (XorVMask (VectorMaskCast (VectorMaskCmp src1 src2 cond)) (MaskAll m1))
|
||||
// => (VectorMaskCast (VectorMaskCmp src1 src2 ncond))
|
||||
// cond can be eq, ne, le, ge, lt, gt, ule, uge, ult and ugt.
|
||||
// ncond is the negative comparison of cond.
|
||||
//
|
||||
// For float and double types:
|
||||
// (XorV (VectorMaskCast (VectorMaskCmp src1 src2 cond)) (Replicate -1))
|
||||
// => (VectorMaskCast (VectorMaskCmp src1 src2 ncond))
|
||||
// (XorVMask (VectorMaskCast (VectorMaskCmp src1 src2 cond)) (MaskAll m1))
|
||||
// => (VectorMaskCast (VectorMaskCmp src1 src2 ncond))
|
||||
// cond can be eq or ne.
|
||||
Node* XorVNode::Ideal_XorV_VectorMaskCmp(PhaseGVN* phase, bool can_reshape) {
|
||||
Node* in1 = in(1);
|
||||
Node* in2 = in(2);
|
||||
// Transformations for predicated vectors are not supported for now.
|
||||
if (is_predicated_vector() ||
|
||||
in1->is_predicated_vector() ||
|
||||
in2->is_predicated_vector()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// XorV/XorVMask is commutative, swap VectorMaskCmp/VectorMaskCast to in1.
|
||||
if (VectorNode::is_all_ones_vector(in1)) {
|
||||
swap(in1, in2);
|
||||
}
|
||||
|
||||
bool with_vector_mask_cast = false;
|
||||
// Required conditions:
|
||||
// 1. VectorMaskCast and VectorMaskCmp should only have a single use,
|
||||
// otherwise the optimization may be unprofitable.
|
||||
// 2. The predicate of VectorMaskCmp should be negatable.
|
||||
// 3. The second input should be an all true vector mask.
|
||||
if (in1->Opcode() == Op_VectorMaskCast) {
|
||||
if (in1->outcnt() != 1) {
|
||||
return nullptr;
|
||||
}
|
||||
with_vector_mask_cast = true;
|
||||
in1 = in1->in(1);
|
||||
}
|
||||
if (in1->Opcode() != Op_VectorMaskCmp ||
|
||||
in1->outcnt() != 1 ||
|
||||
!in1->as_VectorMaskCmp()->predicate_can_be_negated() ||
|
||||
!VectorNode::is_all_ones_vector(in2)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
BoolTest::mask neg_cond = BoolTest::negate_mask((in1->as_VectorMaskCmp())->get_predicate());
|
||||
ConINode* predicate_node = phase->intcon(neg_cond);
|
||||
const TypeVect* vt = in1->as_Vector()->vect_type();
|
||||
Node* res = new VectorMaskCmpNode(neg_cond, in1->in(1), in1->in(2), predicate_node, vt);
|
||||
if (with_vector_mask_cast) {
|
||||
// We optimized out a VectorMaskCast, regenerate one to ensure type correctness.
|
||||
res = new VectorMaskCastNode(phase->transform(res), vect_type());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Node* XorVNode::Ideal(PhaseGVN* phase, bool can_reshape) {
|
||||
// (XorV src src) => (Replicate zero)
|
||||
// (XorVMask src src) => (MaskAll zero)
|
||||
@ -2281,6 +2374,11 @@ Node* XorVNode::Ideal(PhaseGVN* phase, bool can_reshape) {
|
||||
Node* zero = phase->transform(phase->zerocon(bt));
|
||||
return VectorNode::scalar2vector(zero, length(), bt, bottom_type()->isa_vectmask() != nullptr);
|
||||
}
|
||||
|
||||
Node* res = Ideal_XorV_VectorMaskCmp(phase, can_reshape);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
return VectorNode::Ideal(phase, can_reshape);
|
||||
}
|
||||
|
||||
|
||||
@ -1013,6 +1013,7 @@ class XorVNode : public VectorNode {
|
||||
XorVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
|
||||
virtual int Opcode() const;
|
||||
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
|
||||
Node* Ideal_XorV_VectorMaskCmp(PhaseGVN* phase, bool can_reshape);
|
||||
};
|
||||
|
||||
//------------------------------XorReductionVNode--------------------------------------
|
||||
@ -1676,6 +1677,7 @@ class VectorMaskCmpNode : public VectorNode {
|
||||
virtual bool cmp( const Node &n ) const {
|
||||
return VectorNode::cmp(n) && _predicate == ((VectorMaskCmpNode&)n)._predicate;
|
||||
}
|
||||
bool predicate_can_be_negated();
|
||||
BoolTest::mask get_predicate() { return _predicate; }
|
||||
#ifndef PRODUCT
|
||||
virtual void dump_spec(outputStream *st) const;
|
||||
|
||||
@ -2295,6 +2295,11 @@ public class IRNode {
|
||||
vectorNode(VECTOR_MASK_CMP_D, "VectorMaskCmp", TYPE_DOUBLE);
|
||||
}
|
||||
|
||||
public static final String VECTOR_MASK_CMP = PREFIX + "VECTOR_MASK_CMP" + POSTFIX;
|
||||
static {
|
||||
beforeMatchingNameRegex(VECTOR_MASK_CMP, "VectorMaskCmp");
|
||||
}
|
||||
|
||||
public static final String VECTOR_CAST_B2S = VECTOR_PREFIX + "VECTOR_CAST_B2S" + POSTFIX;
|
||||
static {
|
||||
vectorNode(VECTOR_CAST_B2S, "VectorCastB2X", TYPE_SHORT);
|
||||
@ -2705,6 +2710,11 @@ public class IRNode {
|
||||
vectorNode(XOR_VL, "XorV", TYPE_LONG);
|
||||
}
|
||||
|
||||
public static final String XOR_V = PREFIX + "XOR_V" + POSTFIX;
|
||||
static {
|
||||
beforeMatchingNameRegex(XOR_V, "XorV");
|
||||
}
|
||||
|
||||
public static final String XOR_V_MASK = PREFIX + "XOR_V_MASK" + POSTFIX;
|
||||
static {
|
||||
beforeMatchingNameRegex(XOR_V_MASK, "XorVMask");
|
||||
|
||||
1299
test/hotspot/jtreg/compiler/vectorapi/VectorMaskCompareNotTest.java
Normal file
1299
test/hotspot/jtreg/compiler/vectorapi/VectorMaskCompareNotTest.java
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,220 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*/
|
||||
|
||||
package org.openjdk.bench.jdk.incubator.vector;
|
||||
|
||||
import org.openjdk.jmh.annotations.*;
|
||||
import org.openjdk.jmh.infra.*;
|
||||
|
||||
import jdk.incubator.vector.*;
|
||||
import java.lang.invoke.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.Random;
|
||||
|
||||
@BenchmarkMode(Mode.Throughput)
|
||||
@OutputTimeUnit(TimeUnit.SECONDS)
|
||||
@State(Scope.Thread)
|
||||
@Warmup(iterations = 5, time = 1)
|
||||
@Measurement(iterations = 5, time = 1)
|
||||
@Fork(value = 2, jvmArgs = { "--add-modules=jdk.incubator.vector" })
|
||||
public abstract class MaskCompareNotBenchmark {
|
||||
@Param({"4096"})
|
||||
protected int ARRAYLEN;
|
||||
|
||||
// Abstract method to get comparison operator from subclasses
|
||||
protected abstract String getComparisonOperatorName();
|
||||
|
||||
// To get compile-time constants for comparison operation
|
||||
static final MutableCallSite MUTABLE_COMPARISON_CONSTANT = new MutableCallSite(MethodType.methodType(VectorOperators.Comparison.class));
|
||||
static final MethodHandle MUTABLE_COMPARISON_CONSTANT_HANDLE = MUTABLE_COMPARISON_CONSTANT.dynamicInvoker();
|
||||
|
||||
private static Random r = new Random();
|
||||
|
||||
protected static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_MAX;
|
||||
protected static final VectorSpecies<Short> S_SPECIES = ShortVector.SPECIES_MAX;
|
||||
protected static final VectorSpecies<Integer> I_SPECIES = IntVector.SPECIES_MAX;
|
||||
protected static final VectorSpecies<Long> L_SPECIES = LongVector.SPECIES_MAX;
|
||||
protected static final VectorSpecies<Float> F_SPECIES = FloatVector.SPECIES_MAX;
|
||||
protected static final VectorSpecies<Double> D_SPECIES = DoubleVector.SPECIES_MAX;
|
||||
|
||||
protected boolean[] mr;
|
||||
protected byte[] ba;
|
||||
protected byte[] bb;
|
||||
protected short[] sa;
|
||||
protected short[] sb;
|
||||
protected int[] ia;
|
||||
protected int[] ib;
|
||||
protected long[] la;
|
||||
protected long[] lb;
|
||||
protected float[] fa;
|
||||
protected float[] fb;
|
||||
protected double[] da;
|
||||
protected double[] db;
|
||||
|
||||
@Setup
|
||||
public void init() throws Throwable {
|
||||
mr = new boolean[ARRAYLEN];
|
||||
ba = new byte[ARRAYLEN];
|
||||
bb = new byte[ARRAYLEN];
|
||||
sa = new short[ARRAYLEN];
|
||||
sb = new short[ARRAYLEN];
|
||||
ia = new int[ARRAYLEN];
|
||||
ib = new int[ARRAYLEN];
|
||||
la = new long[ARRAYLEN];
|
||||
lb = new long[ARRAYLEN];
|
||||
fa = new float[ARRAYLEN];
|
||||
fb = new float[ARRAYLEN];
|
||||
da = new double[ARRAYLEN];
|
||||
db = new double[ARRAYLEN];
|
||||
|
||||
for (int i = 0; i < ARRAYLEN; i++) {
|
||||
mr[i] = r.nextBoolean();
|
||||
ba[i] = (byte) r.nextInt();
|
||||
bb[i] = (byte) r.nextInt();
|
||||
sa[i] = (short) r.nextInt();
|
||||
sb[i] = (short) r.nextInt();
|
||||
ia[i] = r.nextInt();
|
||||
ib[i] = r.nextInt();
|
||||
la[i] = r.nextLong();
|
||||
lb[i] = r.nextLong();
|
||||
fa[i] = r.nextFloat();
|
||||
fb[i] = r.nextFloat();
|
||||
da[i] = r.nextDouble();
|
||||
db[i] = r.nextDouble();
|
||||
}
|
||||
|
||||
VectorOperators.Comparison comparisonOp = getComparisonOperator(getComparisonOperatorName());
|
||||
MethodHandle constant = MethodHandles.constant(VectorOperators.Comparison.class, comparisonOp);
|
||||
MUTABLE_COMPARISON_CONSTANT.setTarget(constant);
|
||||
}
|
||||
|
||||
@CompilerControl(CompilerControl.Mode.INLINE)
|
||||
private static VectorOperators.Comparison getComparisonOperator(String op) {
|
||||
switch (op) {
|
||||
case "EQ": return VectorOperators.EQ;
|
||||
case "NE": return VectorOperators.NE;
|
||||
case "LT": return VectorOperators.LT;
|
||||
case "LE": return VectorOperators.LE;
|
||||
case "GT": return VectorOperators.GT;
|
||||
case "GE": return VectorOperators.GE;
|
||||
case "ULT": return VectorOperators.ULT;
|
||||
case "ULE": return VectorOperators.ULE;
|
||||
case "UGT": return VectorOperators.UGT;
|
||||
case "UGE": return VectorOperators.UGE;
|
||||
default: throw new IllegalArgumentException("Unknown comparison operator: " + op);
|
||||
}
|
||||
}
|
||||
|
||||
@CompilerControl(CompilerControl.Mode.INLINE)
|
||||
protected VectorOperators.Comparison comparison_con() throws Throwable {
|
||||
return (VectorOperators.Comparison) MUTABLE_COMPARISON_CONSTANT_HANDLE.invokeExact();
|
||||
}
|
||||
|
||||
// Subclasses with different comparison operators
|
||||
public static class IntegerComparisons extends MaskCompareNotBenchmark {
|
||||
@Param({"EQ", "NE", "LT", "LE", "GT", "GE", "ULT", "ULE", "UGT", "UGE"})
|
||||
public String COMPARISON_OP;
|
||||
|
||||
@Override
|
||||
protected String getComparisonOperatorName() {
|
||||
return COMPARISON_OP;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public void testCompareMaskNotByte() throws Throwable {
|
||||
VectorOperators.Comparison op = comparison_con();
|
||||
ByteVector bv = ByteVector.fromArray(B_SPECIES, bb, 0);
|
||||
for (int j = 0; j < ARRAYLEN; j += B_SPECIES.length()) {
|
||||
ByteVector av = ByteVector.fromArray(B_SPECIES, ba, j);
|
||||
VectorMask<Byte> m = av.compare(op, bv).not();
|
||||
m.intoArray(mr, j);
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public void testCompareMaskNotShort() throws Throwable {
|
||||
VectorOperators.Comparison op = comparison_con();
|
||||
ShortVector bv = ShortVector.fromArray(S_SPECIES, sb, 0);
|
||||
for (int j = 0; j < ARRAYLEN; j += S_SPECIES.length()) {
|
||||
ShortVector av = ShortVector.fromArray(S_SPECIES, sa, j);
|
||||
VectorMask<Short> m = av.compare(op, bv).not();
|
||||
m.intoArray(mr, j);
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public void testCompareMaskNotInt() throws Throwable {
|
||||
VectorOperators.Comparison op = comparison_con();
|
||||
IntVector bv = IntVector.fromArray(I_SPECIES, ib, 0);
|
||||
for (int j = 0; j < ARRAYLEN; j += I_SPECIES.length()) {
|
||||
IntVector av = IntVector.fromArray(I_SPECIES, ia, j);
|
||||
VectorMask<Integer> m = av.compare(op, bv).not();
|
||||
m.intoArray(mr, j);
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public void testCompareMaskNotLong() throws Throwable {
|
||||
VectorOperators.Comparison op = comparison_con();
|
||||
LongVector bv = LongVector.fromArray(L_SPECIES, lb, 0);
|
||||
for (int j = 0; j < ARRAYLEN; j += L_SPECIES.length()) {
|
||||
LongVector av = LongVector.fromArray(L_SPECIES, la, j);
|
||||
VectorMask<Long> m = av.compare(op, bv).not();
|
||||
m.intoArray(mr, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static class FloatingPointComparisons extends MaskCompareNotBenchmark {
|
||||
// "ULT", "ULE", "UGT", "UGE" are not supported for floating point types
|
||||
@Param({"EQ", "NE", "LT", "LE", "GT", "GE"})
|
||||
public String COMPARISON_OP;
|
||||
|
||||
@Override
|
||||
protected String getComparisonOperatorName() {
|
||||
return COMPARISON_OP;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public void testCompareMaskNotFloat() throws Throwable {
|
||||
VectorOperators.Comparison op = comparison_con();
|
||||
FloatVector bv = FloatVector.fromArray(F_SPECIES, fb, 0);
|
||||
for (int j = 0; j < ARRAYLEN; j += F_SPECIES.length()) {
|
||||
FloatVector av = FloatVector.fromArray(F_SPECIES, fa, j);
|
||||
VectorMask<Float> m = av.compare(op, bv).not();
|
||||
m.intoArray(mr, j);
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public void testCompareMaskNotDouble() throws Throwable {
|
||||
VectorOperators.Comparison op = comparison_con();
|
||||
DoubleVector bv = DoubleVector.fromArray(D_SPECIES, db, 0);
|
||||
for (int j = 0; j < ARRAYLEN; j += D_SPECIES.length()) {
|
||||
DoubleVector av = DoubleVector.fromArray(D_SPECIES, da, j);
|
||||
VectorMask<Double> m = av.compare(op, bv).not();
|
||||
m.intoArray(mr, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user