8354242: VectorAPI: combine vector not operation with compare

Reviewed-by: epeter, jbhateja, xgong
This commit is contained in:
erifan 2025-09-17 07:32:19 +00:00 committed by Xiaohong Gong
parent c2c44a061a
commit 45cc515f45
7 changed files with 1635 additions and 1 deletions

View File

@ -1216,6 +1216,9 @@ bool Node::has_special_unique_user() const {
} else if ((is_IfFalse() || is_IfTrue()) && n->is_If()) {
// See IfNode::fold_compares
return true;
} else if (n->Opcode() == Op_XorV || n->Opcode() == Op_XorVMask) {
// Condition for XorVMask(VectorMaskCmp(x,y,cond), MaskAll(true)) ==> VectorMaskCmp(x,y,ncond)
return true;
} else {
return false;
}

View File

@ -328,7 +328,9 @@ struct BoolTest {
// a simple char array where each element is the ASCII version of a 'mask'
// enum from above.
mask commute( ) const { return mask("032147658"[_test]-'0'); }
mask negate( ) const { return mask(_test^4); }
mask negate( ) const { return negate_mask(_test); }
// Return the negative mask for the given mask, for both signed and unsigned comparison.
static mask negate_mask(mask btm) { return mask(btm ^ 4); }
bool is_canonical( ) const { return (_test == BoolTest::ne || _test == BoolTest::lt || _test == BoolTest::le || _test == BoolTest::overflow); }
bool is_less( ) const { return _test == BoolTest::lt || _test == BoolTest::le; }
bool is_greater( ) const { return _test == BoolTest::gt || _test == BoolTest::ge; }

View File

@ -2268,6 +2268,99 @@ Node* OrVNode::Identity(PhaseGVN* phase) {
return redundant_logical_identity(this);
}
// Returns whether (XorV (VectorMaskCmp) -1) can be optimized by negating the
// comparison operation.
bool VectorMaskCmpNode::predicate_can_be_negated() {
switch (_predicate) {
case BoolTest::eq:
case BoolTest::ne:
// eq and ne also apply to floating-point special values like NaN and infinities.
return true;
case BoolTest::le:
case BoolTest::ge:
case BoolTest::lt:
case BoolTest::gt:
case BoolTest::ule:
case BoolTest::uge:
case BoolTest::ult:
case BoolTest::ugt: {
BasicType bt = vect_type()->element_basic_type();
// For float and double, we don't know if either comparison operand is a
// NaN, NaN {le|ge|lt|gt} anything is false, resulting in inconsistent
// results before and after negation.
return is_integral_type(bt);
}
default:
return false;
}
}
// Folds the negation of a comparison mask into the comparison itself. The
// negation shows up as an XOR against an all-ones vector / all-true mask.
//
// Integer element types (cond is one of eq, ne, le, ge, lt, gt, ule, uge, ult,
// ugt; ncond is the negated comparison of cond):
//   (XorV (VectorMaskCmp src1 src2 cond) (Replicate -1))
//     => (VectorMaskCmp src1 src2 ncond)
//   (XorVMask (VectorMaskCmp src1 src2 cond) (MaskAll m1))
//     => (VectorMaskCmp src1 src2 ncond)
//   (XorV (VectorMaskCast (VectorMaskCmp src1 src2 cond)) (Replicate -1))
//     => (VectorMaskCast (VectorMaskCmp src1 src2 ncond))
//   (XorVMask (VectorMaskCast (VectorMaskCmp src1 src2 cond)) (MaskAll m1))
//     => (VectorMaskCast (VectorMaskCmp src1 src2 ncond))
//
// Float and double element types (cond is limited to eq or ne):
//   (XorV (VectorMaskCast (VectorMaskCmp src1 src2 cond)) (Replicate -1))
//     => (VectorMaskCast (VectorMaskCmp src1 src2 ncond))
//   (XorVMask (VectorMaskCast (VectorMaskCmp src1 src2 cond)) (MaskAll m1))
//     => (VectorMaskCast (VectorMaskCmp src1 src2 ncond))
//
// Returns the replacement node, or nullptr if the pattern does not apply.
// The can_reshape argument is not consulted here.
Node* XorVNode::Ideal_XorV_VectorMaskCmp(PhaseGVN* phase, bool can_reshape) {
  Node* mask_side = in(1);
  Node* ones_side = in(2);

  // Predicated vectors are not handled by this transformation for now.
  if (is_predicated_vector()) {
    return nullptr;
  }
  if (mask_side->is_predicated_vector() || ones_side->is_predicated_vector()) {
    return nullptr;
  }

  // XorV/XorVMask is commutative: if the all-ones operand is the first input,
  // exchange the two so the VectorMaskCmp/VectorMaskCast candidate comes first.
  if (VectorNode::is_all_ones_vector(mask_side)) {
    swap(mask_side, ones_side);
  }

  // Step over an optional VectorMaskCast. It must have a single use,
  // otherwise the rewrite may be unprofitable.
  const bool through_cast = (mask_side->Opcode() == Op_VectorMaskCast);
  Node* cmp_candidate = mask_side;
  if (through_cast) {
    if (mask_side->outcnt() != 1) {
      return nullptr;
    }
    cmp_candidate = mask_side->in(1);
  }

  // The remaining requirements, checked in order:
  //  - the candidate is a VectorMaskCmp with exactly one use,
  //  - its predicate can be negated,
  //  - the other XOR operand is an all-true mask / all-ones vector.
  if (cmp_candidate->Opcode() != Op_VectorMaskCmp) {
    return nullptr;
  }
  if (cmp_candidate->outcnt() != 1) {
    return nullptr;
  }
  VectorMaskCmpNode* cmp_node = cmp_candidate->as_VectorMaskCmp();
  if (!cmp_node->predicate_can_be_negated()) {
    return nullptr;
  }
  if (!VectorNode::is_all_ones_vector(ones_side)) {
    return nullptr;
  }

  // Build the comparison with the negated predicate on the same operands and
  // with the same vector type as the original comparison.
  const BoolTest::mask negated_cond = BoolTest::negate_mask(cmp_node->get_predicate());
  ConINode* negated_cond_con = phase->intcon(negated_cond);
  const TypeVect* cmp_type = cmp_candidate->as_Vector()->vect_type();
  Node* negated_cmp = new VectorMaskCmpNode(negated_cond,
                                            cmp_candidate->in(1),
                                            cmp_candidate->in(2),
                                            negated_cond_con,
                                            cmp_type);
  if (!through_cast) {
    return negated_cmp;
  }

  // A VectorMaskCast was optimized out above; regenerate one with this node's
  // vector type to ensure type correctness.
  return new VectorMaskCastNode(phase->transform(negated_cmp), vect_type());
}
Node* XorVNode::Ideal(PhaseGVN* phase, bool can_reshape) {
// (XorV src src) => (Replicate zero)
// (XorVMask src src) => (MaskAll zero)
@ -2281,6 +2374,11 @@ Node* XorVNode::Ideal(PhaseGVN* phase, bool can_reshape) {
Node* zero = phase->transform(phase->zerocon(bt));
return VectorNode::scalar2vector(zero, length(), bt, bottom_type()->isa_vectmask() != nullptr);
}
Node* res = Ideal_XorV_VectorMaskCmp(phase, can_reshape);
if (res != nullptr) {
return res;
}
return VectorNode::Ideal(phase, can_reshape);
}

View File

@ -1013,6 +1013,7 @@ class XorVNode : public VectorNode {
XorVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1,in2,vt) {}
virtual int Opcode() const;
virtual Node* Ideal(PhaseGVN* phase, bool can_reshape);
Node* Ideal_XorV_VectorMaskCmp(PhaseGVN* phase, bool can_reshape);
};
//------------------------------XorReductionVNode--------------------------------------
@ -1676,6 +1677,7 @@ class VectorMaskCmpNode : public VectorNode {
virtual bool cmp( const Node &n ) const {
return VectorNode::cmp(n) && _predicate == ((VectorMaskCmpNode&)n)._predicate;
}
bool predicate_can_be_negated();
BoolTest::mask get_predicate() { return _predicate; }
#ifndef PRODUCT
virtual void dump_spec(outputStream *st) const;

View File

@ -2295,6 +2295,11 @@ public class IRNode {
vectorNode(VECTOR_MASK_CMP_D, "VectorMaskCmp", TYPE_DOUBLE);
}
// IR placeholder for the "VectorMaskCmp" node name. It is registered through
// beforeMatchingNameRegex with only the node name (no element-type argument).
// NOTE(review): presumably matches the node regardless of element type -- confirm
// against beforeMatchingNameRegex.
public static final String VECTOR_MASK_CMP = PREFIX + "VECTOR_MASK_CMP" + POSTFIX;
static {
beforeMatchingNameRegex(VECTOR_MASK_CMP, "VectorMaskCmp");
}
public static final String VECTOR_CAST_B2S = VECTOR_PREFIX + "VECTOR_CAST_B2S" + POSTFIX;
static {
vectorNode(VECTOR_CAST_B2S, "VectorCastB2X", TYPE_SHORT);
@ -2705,6 +2710,11 @@ public class IRNode {
vectorNode(XOR_VL, "XorV", TYPE_LONG);
}
// IR placeholder for the "XorV" node name. It is registered through
// beforeMatchingNameRegex with only the node name (no element-type argument).
// NOTE(review): presumably matches the node regardless of element type -- confirm
// against beforeMatchingNameRegex.
public static final String XOR_V = PREFIX + "XOR_V" + POSTFIX;
static {
beforeMatchingNameRegex(XOR_V, "XorV");
}
public static final String XOR_V_MASK = PREFIX + "XOR_V_MASK" + POSTFIX;
static {
beforeMatchingNameRegex(XOR_V_MASK, "XorVMask");

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,220 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.jdk.incubator.vector;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.*;
import jdk.incubator.vector.*;
import java.lang.invoke.*;
import java.util.concurrent.TimeUnit;
import java.util.Random;
/**
 * JMH throughput benchmark for the pattern {@code a.compare(op, b).not()},
 * i.e. a vector comparison whose resulting mask is immediately negated and
 * stored into a boolean array.
 *
 * <p>The comparison operator is chosen per run by the concrete nested
 * subclasses through a {@code @Param} field and is handed to the benchmark
 * kernels via a {@link MutableCallSite} bound to a constant method handle.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Thread)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 2, jvmArgs = { "--add-modules=jdk.incubator.vector" })
public abstract class MaskCompareNotBenchmark {
    /** Length of every input and output array. */
    @Param({"4096"})
    protected int ARRAYLEN;

    /**
     * Supplied by each subclass: the name of the comparison operator to
     * benchmark (for example {@code "EQ"}).
     */
    protected abstract String getComparisonOperatorName();

    // Call site used to obtain the comparison operator as a compile-time
    // constant; its target is installed in init().
    static final MutableCallSite MUTABLE_COMPARISON_CONSTANT = new MutableCallSite(MethodType.methodType(VectorOperators.Comparison.class));
    static final MethodHandle MUTABLE_COMPARISON_CONSTANT_HANDLE = MUTABLE_COMPARISON_CONSTANT.dynamicInvoker();

    // Source of the random input data.
    private static Random rng = new Random();

    // Widest available species for each element type.
    protected static final VectorSpecies<Byte> B_SPECIES = ByteVector.SPECIES_MAX;
    protected static final VectorSpecies<Short> S_SPECIES = ShortVector.SPECIES_MAX;
    protected static final VectorSpecies<Integer> I_SPECIES = IntVector.SPECIES_MAX;
    protected static final VectorSpecies<Long> L_SPECIES = LongVector.SPECIES_MAX;
    protected static final VectorSpecies<Float> F_SPECIES = FloatVector.SPECIES_MAX;
    protected static final VectorSpecies<Double> D_SPECIES = DoubleVector.SPECIES_MAX;

    // Destination of the negated masks.
    protected boolean[] mr;
    // Comparison operands: the "a" arrays are walked chunk by chunk, the "b"
    // arrays provide the vector loaded once from index 0.
    protected byte[] ba;
    protected byte[] bb;
    protected short[] sa;
    protected short[] sb;
    protected int[] ia;
    protected int[] ib;
    protected long[] la;
    protected long[] lb;
    protected float[] fa;
    protected float[] fb;
    protected double[] da;
    protected double[] db;

    /**
     * Allocates all arrays, fills them with random values and binds the
     * call site to the comparison operator named by the subclass.
     */
    @Setup
    public void init() throws Throwable {
        mr = new boolean[ARRAYLEN];
        ba = new byte[ARRAYLEN];
        bb = new byte[ARRAYLEN];
        sa = new short[ARRAYLEN];
        sb = new short[ARRAYLEN];
        ia = new int[ARRAYLEN];
        ib = new int[ARRAYLEN];
        la = new long[ARRAYLEN];
        lb = new long[ARRAYLEN];
        fa = new float[ARRAYLEN];
        fb = new float[ARRAYLEN];
        da = new double[ARRAYLEN];
        db = new double[ARRAYLEN];

        for (int idx = 0; idx < ARRAYLEN; idx++) {
            mr[idx] = rng.nextBoolean();
            ba[idx] = (byte) rng.nextInt();
            bb[idx] = (byte) rng.nextInt();
            sa[idx] = (short) rng.nextInt();
            sb[idx] = (short) rng.nextInt();
            ia[idx] = rng.nextInt();
            ib[idx] = rng.nextInt();
            la[idx] = rng.nextLong();
            lb[idx] = rng.nextLong();
            fa[idx] = rng.nextFloat();
            fb[idx] = rng.nextFloat();
            da[idx] = rng.nextDouble();
            db[idx] = rng.nextDouble();
        }

        // Resolve the operator name and install it as the call site's constant target.
        MUTABLE_COMPARISON_CONSTANT.setTarget(
                MethodHandles.constant(VectorOperators.Comparison.class,
                                       getComparisonOperator(getComparisonOperatorName())));
    }

    /**
     * Maps an operator name to the matching {@link VectorOperators.Comparison}.
     *
     * @throws IllegalArgumentException if the name is not a known operator
     */
    @CompilerControl(CompilerControl.Mode.INLINE)
    private static VectorOperators.Comparison getComparisonOperator(String op) {
        return switch (op) {
            case "EQ" -> VectorOperators.EQ;
            case "NE" -> VectorOperators.NE;
            case "LT" -> VectorOperators.LT;
            case "LE" -> VectorOperators.LE;
            case "GT" -> VectorOperators.GT;
            case "GE" -> VectorOperators.GE;
            case "ULT" -> VectorOperators.ULT;
            case "ULE" -> VectorOperators.ULE;
            case "UGT" -> VectorOperators.UGT;
            case "UGE" -> VectorOperators.UGE;
            default -> throw new IllegalArgumentException("Unknown comparison operator: " + op);
        };
    }

    /** Reads the comparison operator currently bound to the call site. */
    @CompilerControl(CompilerControl.Mode.INLINE)
    protected VectorOperators.Comparison comparison_con() throws Throwable {
        // The cast must stay on the invokeExact call: it fixes the call's return type.
        VectorOperators.Comparison bound = (VectorOperators.Comparison) MUTABLE_COMPARISON_CONSTANT_HANDLE.invokeExact();
        return bound;
    }

    /** Kernels for the integral element types; covers signed and unsigned operators. */
    public static class IntegerComparisons extends MaskCompareNotBenchmark {
        @Param({"EQ", "NE", "LT", "LE", "GT", "GE", "ULT", "ULE", "UGT", "UGE"})
        public String COMPARISON_OP;

        @Override
        protected String getComparisonOperatorName() {
            return COMPARISON_OP;
        }

        @Benchmark
        public void testCompareMaskNotByte() throws Throwable {
            VectorOperators.Comparison cmp = comparison_con();
            ByteVector rhs = ByteVector.fromArray(B_SPECIES, bb, 0);
            for (int offset = 0; offset < ARRAYLEN; offset += B_SPECIES.length()) {
                ByteVector lhs = ByteVector.fromArray(B_SPECIES, ba, offset);
                lhs.compare(cmp, rhs).not().intoArray(mr, offset);
            }
        }

        @Benchmark
        public void testCompareMaskNotShort() throws Throwable {
            VectorOperators.Comparison cmp = comparison_con();
            ShortVector rhs = ShortVector.fromArray(S_SPECIES, sb, 0);
            for (int offset = 0; offset < ARRAYLEN; offset += S_SPECIES.length()) {
                ShortVector lhs = ShortVector.fromArray(S_SPECIES, sa, offset);
                lhs.compare(cmp, rhs).not().intoArray(mr, offset);
            }
        }

        @Benchmark
        public void testCompareMaskNotInt() throws Throwable {
            VectorOperators.Comparison cmp = comparison_con();
            IntVector rhs = IntVector.fromArray(I_SPECIES, ib, 0);
            for (int offset = 0; offset < ARRAYLEN; offset += I_SPECIES.length()) {
                IntVector lhs = IntVector.fromArray(I_SPECIES, ia, offset);
                lhs.compare(cmp, rhs).not().intoArray(mr, offset);
            }
        }

        @Benchmark
        public void testCompareMaskNotLong() throws Throwable {
            VectorOperators.Comparison cmp = comparison_con();
            LongVector rhs = LongVector.fromArray(L_SPECIES, lb, 0);
            for (int offset = 0; offset < ARRAYLEN; offset += L_SPECIES.length()) {
                LongVector lhs = LongVector.fromArray(L_SPECIES, la, offset);
                lhs.compare(cmp, rhs).not().intoArray(mr, offset);
            }
        }
    }

    /** Kernels for the floating-point element types. */
    public static class FloatingPointComparisons extends MaskCompareNotBenchmark {
        // "ULT", "ULE", "UGT", "UGE" are not supported for floating point types
        @Param({"EQ", "NE", "LT", "LE", "GT", "GE"})
        public String COMPARISON_OP;

        @Override
        protected String getComparisonOperatorName() {
            return COMPARISON_OP;
        }

        @Benchmark
        public void testCompareMaskNotFloat() throws Throwable {
            VectorOperators.Comparison cmp = comparison_con();
            FloatVector rhs = FloatVector.fromArray(F_SPECIES, fb, 0);
            for (int offset = 0; offset < ARRAYLEN; offset += F_SPECIES.length()) {
                FloatVector lhs = FloatVector.fromArray(F_SPECIES, fa, offset);
                lhs.compare(cmp, rhs).not().intoArray(mr, offset);
            }
        }

        @Benchmark
        public void testCompareMaskNotDouble() throws Throwable {
            VectorOperators.Comparison cmp = comparison_con();
            DoubleVector rhs = DoubleVector.fromArray(D_SPECIES, db, 0);
            for (int offset = 0; offset < ARRAYLEN; offset += D_SPECIES.length()) {
                DoubleVector lhs = DoubleVector.fromArray(D_SPECIES, da, offset);
                lhs.compare(cmp, rhs).not().intoArray(mr, offset);
            }
        }
    }
}