8352635: Improve inferencing of Float16 operations with constant inputs

Reviewed-by: epeter, sviswanathan
This commit is contained in:
Jatin Bhateja 2025-06-26 15:42:43 +00:00
parent 7f702cf483
commit a49ecb26c5
7 changed files with 455 additions and 142 deletions

View File

@ -264,8 +264,68 @@ Node* ConvF2HFNode::Ideal(PhaseGVN* phase, bool can_reshape) {
return new ReinterpretHF2SNode(binop);
}
}
// Detects following ideal graph pattern
// ConvF2HF(binopF(conF, ConvHF2F(varS))) =>
// ReinterpretHF2SNode(binopHF(conHF, ReinterpretS2HFNode(varS)))
if (Float16NodeFactory::is_float32_binary_oper(in(1)->Opcode())) {
Node* binopF = in(1);
// Check if the incoming binary operation has one floating point constant
// input and the other input is a half precision to single precision upcasting node.
// We land here because a prior HalfFloat to Float conversion promotes
// an integral constant holding Float16 value to a floating point constant.
// i.e. ConvHF2F ConI(short) => ConF
Node* conF = nullptr;
Node* varS = nullptr;
if (binopF->in(1)->is_Con() && binopF->in(2)->Opcode() == Op_ConvHF2F) {
conF = binopF->in(1);
varS = binopF->in(2)->in(1);
} else if (binopF->in(2)->is_Con() && binopF->in(1)->Opcode() == Op_ConvHF2F) {
conF = binopF->in(2);
varS = binopF->in(1)->in(1);
}
if (conF != nullptr &&
varS != nullptr &&
conF->bottom_type()->isa_float_constant() != nullptr &&
Matcher::match_rule_supported(Float16NodeFactory::get_float16_binary_oper(binopF->Opcode())) &&
Matcher::match_rule_supported(Op_ReinterpretS2HF) &&
Matcher::match_rule_supported(Op_ReinterpretHF2S) &&
StubRoutines::hf2f_adr() != nullptr &&
StubRoutines::f2hf_adr() != nullptr) {
jfloat con = conF->bottom_type()->getf();
// Conditions under which floating point constant can be considered for a pattern match.
// 1. conF must lie within Float16 value range, otherwise we would have rounding issues:
// Doing the operation in float32 and then rounding is not the same as
// rounding first and doing the operation in float16.
// 2. If a constant value is one of the valid IEEE 754 binary32 NaN bit patterns
// then it's safe to consider it for pattern match because of the following reasons:
// a. As per section 2.8 of JVMS, Java Virtual Machine does not support
// signaling NaN value.
// b. Any signaling NaN which takes part in a non-comparison expression
// results in a quiet NaN but preserves the significand bits of signaling NaN.
// c. The pattern being matched includes a Float to Float16 conversion after binary
// expression, this downcast will still preserve the significand bits of binary32 NaN.
bool isnan = g_isnan((jdouble)con);
if (StubRoutines::hf2f(StubRoutines::f2hf(con)) == con || isnan) {
Node* newVarHF = phase->transform(new ReinterpretS2HFNode(varS));
Node* conHF = phase->makecon(TypeH::make(con));
Node* binopHF = nullptr;
// Preserving original input order for semantic correctness
// of non-commutative operation.
if (binopF->in(1) == conF) {
binopHF = phase->transform(Float16NodeFactory::make(binopF->Opcode(), binopF->in(0), conHF, newVarHF));
} else {
binopHF = phase->transform(Float16NodeFactory::make(binopF->Opcode(), binopF->in(0), newVarHF, conHF));
}
return new ReinterpretHF2SNode(binopHF);
}
}
}
return nullptr;
}
//=============================================================================
//------------------------------Value------------------------------------------
const Type* ConvF2INode::Value(PhaseGVN* phase) const {

View File

@ -823,39 +823,11 @@ const Type* DivHFNode::Value(PhaseGVN* phase) const {
return bot;
}
// x/x == 1, we ignore 0/0.
// Note: if t1 and t2 are zero then result is NaN (JVMS page 213)
// Does not work for variables because of NaN's
if (in(1) == in(2) && t1->base() == Type::HalfFloatCon &&
!g_isnan(t1->getf()) && g_isfinite(t1->getf()) && t1->getf() != 0.0) { // could be negative ZERO or NaN
return TypeH::ONE;
}
if (t2 == TypeH::ONE) {
return t1;
}
// If divisor is a constant and not zero, divide the numbers
if (t1->base() == Type::HalfFloatCon &&
t2->base() == Type::HalfFloatCon &&
t2->getf() != 0.0) {
// could be negative zero
t2->base() == Type::HalfFloatCon) {
return TypeH::make(t1->getf() / t2->getf());
}
// If the dividend is a constant zero
// Note: if t1 and t2 are zero then result is NaN (JVMS page 213)
// Test TypeHF::ZERO is not sufficient as it could be negative zero
if (t1 == TypeH::ZERO && !g_isnan(t2->getf()) && t2->getf() != 0.0) {
return TypeH::ZERO;
}
// If divisor or dividend is nan then result is nan.
if (g_isnan(t1->getf()) || g_isnan(t2->getf())) {
return TypeH::make(NAN);
}
// Otherwise we give up all hope
return Type::HALF_FLOAT;
}

View File

@ -560,17 +560,12 @@ const Type* SubFPNode::Value(PhaseGVN* phase) const {
//------------------------------sub--------------------------------------------
// A subtract node differences its two inputs.
const Type* SubHFNode::sub(const Type* t1, const Type* t2) const {
// no folding if one of operands is infinity or NaN, do not do constant folding
if(g_isfinite(t1->getf()) && g_isfinite(t2->getf())) {
// Half precision floating point subtraction follows the rules of IEEE 754
// applicable to other floating point types.
if (t1->isa_half_float_constant() != nullptr &&
t2->isa_half_float_constant() != nullptr) {
return TypeH::make(t1->getf() - t2->getf());
}
else if(g_isnan(t1->getf())) {
return t1;
}
else if(g_isnan(t2->getf())) {
return t2;
}
else {
} else {
return Type::HALF_FLOAT;
}
}

View File

@ -302,9 +302,9 @@ public:
const TypeD *isa_double() const; // Returns null if not a Double{Top,Con,Bot}
const TypeD *is_double_constant() const; // Asserts it is a DoubleCon
const TypeD *isa_double_constant() const; // Returns null if not a DoubleCon
const TypeH *isa_half_float() const; // Returns null if not a Float{Top,Con,Bot}
const TypeH *is_half_float_constant() const; // Asserts it is a FloatCon
const TypeH *isa_half_float_constant() const; // Returns null if not a FloatCon
const TypeH *isa_half_float() const; // Returns null if not a HalfFloat{Top,Con,Bot}
const TypeH *is_half_float_constant() const; // Asserts it is a HalfFloatCon
const TypeH *isa_half_float_constant() const; // Returns null if not a HalfFloatCon
const TypeF *isa_float() const; // Returns null if not a Float{Top,Con,Bot}
const TypeF *is_float_constant() const; // Asserts it is a FloatCon
const TypeF *isa_float_constant() const; // Returns null if not a FloatCon

View File

@ -32,10 +32,14 @@
* @run driver TestFloat16ScalarOperations
*/
import compiler.lib.ir_framework.*;
import compiler.lib.verify.*;
import jdk.incubator.vector.Float16;
import static jdk.incubator.vector.Float16.*;
import java.util.Random;
import compiler.lib.generators.Generator;
import static compiler.lib.generators.Generators.G;
public class TestFloat16ScalarOperations {
private static final int count = 1024;
@ -55,19 +59,35 @@ public class TestFloat16ScalarOperations {
private static final Float16 MAX_HALF_ULP = Float16.valueOf(16.0f);
private static final Float16 SIGNALING_NAN = shortBitsToFloat16((short)31807);
private static Random r = jdk.test.lib.Utils.getRandomInstance();
private static Generator<Float> genF = G.uniformFloats();
private static Generator<Short> genHF = G.uniformFloat16s();
private static final Float16 RANDOM1 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue());
private static final Float16 RANDOM2 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue());
private static final Float16 RANDOM3 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue());
private static final Float16 RANDOM4 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue());
private static final Float16 RANDOM5 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue());
private static final Float16 RANDOM1 = Float16.valueOf(genF.next());
private static final Float16 RANDOM2 = Float16.valueOf(genF.next());
private static final Float16 RANDOM3 = Float16.valueOf(genF.next());
private static final Float16 RANDOM4 = Float16.valueOf(genF.next());
private static final Float16 RANDOM5 = Float16.valueOf(genF.next());
private static Float16 RANDOM1_VAR = RANDOM1;
private static Float16 RANDOM2_VAR = RANDOM2;
private static Float16 RANDOM3_VAR = RANDOM3;
private static Float16 RANDOM4_VAR = RANDOM4;
private static Float16 RANDOM5_VAR = RANDOM5;
private static Float16 POSITIVE_ZERO_VAR = POSITIVE_ZERO;
private static final float INEXACT_FP16 = 2051.0f;
private static final float EXACT_FP16 = 2052.0f;
private static final float SNAN_FP16 = Float.intBitsToFloat(0x7F8000F0);
private static final float QNAN_FP16 = Float.intBitsToFloat(0x7FC00000);
private Float16 GOLDEN_DIV_POT;
private Float16 GOLDEN_MUL2;
private short GOLDEN_INEXACT;
private short GOLDEN_EXACT;
private short GOLDEN_RANDOM_PAT1;
private short GOLDEN_RANDOM_PAT2;
private short GOLDEN_SNAN;
private short GOLDEN_QNAN;
public static void main(String args[]) {
Scenario s0 = new Scenario(0, "--add-modules=jdk.incubator.vector", "-Xint");
@ -78,11 +98,19 @@ public class TestFloat16ScalarOperations {
public TestFloat16ScalarOperations() {
src = new short[count];
dst = new short[count];
fl = new float[count];
for (int i = 0; i < count; i++) {
src[i] = Float.floatToFloat16(r.nextFloat() * MAX_VALUE.floatValue());
fl[i] = r.nextFloat();
}
fl = new float[count];
G.fill(genF, fl);
G.fill(genHF, src);
GOLDEN_DIV_POT = testDivByPOT();
GOLDEN_MUL2 = testMulByTWO();
GOLDEN_INEXACT = testInexactFP16ConstantPatterns();
GOLDEN_EXACT = testExactFP16ConstantPatterns();
GOLDEN_RANDOM_PAT1 = testRandomFP16ConstantPatternSet1();
GOLDEN_RANDOM_PAT2 = testRandomFP16ConstantPatternSet2();
GOLDEN_SNAN = testSNaNFP16ConstantPatterns();
GOLDEN_QNAN = testQNaNFP16ConstantPatterns();
}
static void assertResult(float actual, float expected, String msg) {
@ -270,7 +298,7 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.MUL_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void testDivByPOT() {
public Float16 testDivByPOT() {
Float16 res = valueOf(0.0f);
for (int i = 0; i < 50; i++) {
Float16 divisor = valueOf(8.0f);
@ -281,7 +309,12 @@ public class TestFloat16ScalarOperations {
divisor = valueOf(32.0f);
res = add(res, divide(dividend, divisor));
}
dst[0] = float16ToRawShortBits(res);
return res;
}
@Check(test="testDivByPOT")
public void checkDivByPOT(Float16 actual) {
Verify.checkEQ(Float16.float16ToRawShortBits(GOLDEN_DIV_POT), Float16.float16ToRawShortBits(actual));
}
@Test
@ -289,16 +322,151 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.MUL_HF, " 0 ", IRNode.ADD_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void testMulByTWO() {
public Float16 testMulByTWO() {
Float16 res = valueOf(0.0f);
Float16 multiplier = valueOf(2.0f);
for (int i = 0; i < 20; i++) {
Float16 multiplicand = valueOf((float)i);
res = add(res, multiply(multiplicand, multiplier));
}
assertResult(res.floatValue(), (float)((20 * (20 - 1))/2) * 2.0f, "testMulByTWO");
return res;
}
@Check(test="testMulByTWO")
public void checkMulByTWO(Float16 actual) {
Verify.checkEQ(Float16.float16ToRawShortBits(GOLDEN_MUL2), Float16.float16ToRawShortBits(actual));
}
@Test
@IR(counts = {IRNode.ADD_HF, " 0 ", IRNode.SUB_HF, " 0 ", IRNode.MUL_HF, " 0 ", IRNode.DIV_HF, " 0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.ADD_HF, " 0 ", IRNode.SUB_HF, " 0 ", IRNode.MUL_HF, " 0 ", IRNode.DIV_HF, " 0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
// Test point checks various floating point operations with exactly one floating
// point constant value passed as left or right argument, it then downcasts the
// result of computation to a float16 value. This pattern is used to infer a
// float16 IR during idealization. Floating point constant input is not-representable
// in float16 value range and is an inexact float16 value thereby preventing
// float16 IR inference.
public short testInexactFP16ConstantPatterns() {
short res = 0;
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() + INEXACT_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() - INEXACT_FP16);
res += Float.floatToFloat16(INEXACT_FP16 * POSITIVE_ZERO_VAR.floatValue());
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() / INEXACT_FP16);
return res;
}
@Check(test="testInexactFP16ConstantPatterns")
public void checkInexactFP16ConstantPatterns(short actual) {
Verify.checkEQ(GOLDEN_INEXACT, actual);
}
@Test
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public short testSNaNFP16ConstantPatterns() {
short res = 0;
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() + SNAN_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() - SNAN_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() * SNAN_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() / SNAN_FP16);
return res;
}
@Check(test="testSNaNFP16ConstantPatterns")
public void checkSNaNFP16ConstantPatterns(short actual) {
Verify.checkEQ(GOLDEN_SNAN, actual);
}
@Test
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public short testQNaNFP16ConstantPatterns() {
short res = 0;
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() + QNAN_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() - QNAN_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() * QNAN_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() / QNAN_FP16);
return res;
}
@Check(test="testQNaNFP16ConstantPatterns")
public void checkQNaNFP16ConstantPatterns(short actual) {
Verify.checkEQ(GOLDEN_QNAN, actual);
}
@Test
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
// Test point checks various floating point operations with exactly one floating
// point constant value passed as left or right argument, it then downcasts the
// result of computation to a float16 value. This pattern is used to infer a
// Float16 IR during idealization. Floating point constant input is representable
// in Float16 value range thereby leading to a successful Float16 IR inference.
public short testExactFP16ConstantPatterns() {
short res = 0;
res += Float.floatToFloat16(EXACT_FP16 + POSITIVE_ZERO_VAR.floatValue());
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() - EXACT_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() * EXACT_FP16);
res += Float.floatToFloat16(POSITIVE_ZERO_VAR.floatValue() / EXACT_FP16);
return res;
}
@Check(test="testExactFP16ConstantPatterns")
public void checkExactFP16ConstantPatterns(short actual) {
Verify.checkEQ(GOLDEN_EXACT, actual);
}
@Test
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public short testRandomFP16ConstantPatternSet1() {
short res = 0;
res += Float.floatToFloat16(RANDOM1_VAR.floatValue() + RANDOM2.floatValue());
res += Float.floatToFloat16(RANDOM2_VAR.floatValue() - RANDOM3.floatValue());
res += Float.floatToFloat16(RANDOM3_VAR.floatValue() * RANDOM4.floatValue());
res += Float.floatToFloat16(RANDOM4_VAR.floatValue() / RANDOM5.floatValue());
return res;
}
@Check(test="testRandomFP16ConstantPatternSet1")
public void checkRandomFP16ConstantPatternSet1(short actual) {
Verify.checkEQ(GOLDEN_RANDOM_PAT1, actual);
}
@Test
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", IRNode.DIV_HF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public short testRandomFP16ConstantPatternSet2() {
short res = 0;
res += Float.floatToFloat16(RANDOM2.floatValue() + RANDOM1_VAR.floatValue());
res += Float.floatToFloat16(RANDOM3.floatValue() - RANDOM2_VAR.floatValue());
res += Float.floatToFloat16(RANDOM4.floatValue() * RANDOM3_VAR.floatValue());
res += Float.floatToFloat16(RANDOM5.floatValue() / RANDOM4_VAR.floatValue());
return res;
}
@Check(test="testRandomFP16ConstantPatternSet2")
public void checkRandomFP16ConstantPatternSet2(short actual) {
Verify.checkEQ(GOLDEN_RANDOM_PAT2, actual);
}
//
// Tests points for various Float16 constant folding transforms. Following figure represents various
@ -373,41 +541,42 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.SUB_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testSubConstantFolding() {
// If either value is NaN, then the result is NaN.
assertResult(subtract(Float16.NaN, valueOf(2.0f)).floatValue(), Float.NaN, "testAddConstantFolding");
assertResult(subtract(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testAddConstantFolding");
assertResult(subtract(Float16.NaN, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding");
assertResult(subtract(Float16.NaN, valueOf(2.0f)).floatValue(), Float.NaN, "testSubConstantFolding");
assertResult(subtract(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testSubConstantFolding");
assertResult(subtract(Float16.NaN, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testSubConstantFolding");
// The difference of two infinities of opposite sign is NaN.
assertResult(subtract(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding");
assertResult(subtract(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testSubConstantFolding");
// The difference of two infinities of the same sign is NaN.
assertResult(subtract(Float16.POSITIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding");
assertResult(subtract(Float16.NEGATIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding");
assertResult(subtract(Float16.POSITIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testSubConstantFolding");
assertResult(subtract(Float16.NEGATIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testSubConstantFolding");
// The difference of an infinity and a finite value is equal to the infinite operand.
assertResult(subtract(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding");
assertResult(subtract(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testAddConstantFolding");
assertResult(subtract(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testSubConstantFolding");
assertResult(subtract(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testSubConstantFolding");
// The difference of two zeros of opposite sign is positive zero.
assertResult(subtract(NEGATIVE_ZERO, POSITIVE_ZERO).floatValue(), 0.0f, "testAddConstantFolding");
assertResult(subtract(NEGATIVE_ZERO, POSITIVE_ZERO).floatValue(), 0.0f, "testSubConstantFolding");
// Number equal to -MAX_VALUE when subtracted by half upl of MAX_VALUE results into -Inf.
assertResult(subtract(NEGATIVE_MAX_VALUE, MAX_HALF_ULP).floatValue(), Float.NEGATIVE_INFINITY, "testAddConstantFolding");
assertResult(subtract(NEGATIVE_MAX_VALUE, MAX_HALF_ULP).floatValue(), Float.NEGATIVE_INFINITY, "testSubConstantFolding");
// Number equal to -MAX_VALUE when subtracted by a number less than half upl for MAX_VALUE results into -MAX_VALUE.
assertResult(subtract(NEGATIVE_MAX_VALUE, LT_MAX_HALF_ULP).floatValue(), NEGATIVE_MAX_VALUE.floatValue(), "testAddConstantFolding");
assertResult(subtract(NEGATIVE_MAX_VALUE, LT_MAX_HALF_ULP).floatValue(), NEGATIVE_MAX_VALUE.floatValue(), "testSubConstantFolding");
assertResult(subtract(valueOf(1.0f), valueOf(2.0f)).floatValue(), -1.0f, "testAddConstantFolding");
assertResult(subtract(valueOf(1.0f), valueOf(2.0f)).floatValue(), -1.0f, "testSubConstantFolding");
}
@Test
@Warmup(value = 10000)
@IR(counts = {IRNode.MAX_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.MAX_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testMaxConstantFolding() {
// If either value is NaN, then the result is NaN.
assertResult(max(valueOf(2.0f), Float16.NaN).floatValue(), Float.NaN, "testMaxConstantFolding");
@ -428,6 +597,7 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.MIN_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testMinConstantFolding() {
// If either value is NaN, then the result is NaN.
assertResult(min(valueOf(2.0f), Float16.NaN).floatValue(), Float.NaN, "testMinConstantFolding");
@ -447,6 +617,7 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.DIV_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testDivConstantFolding() {
// If either value is NaN, then the result is NaN.
assertResult(divide(Float16.NaN, POSITIVE_ZERO).floatValue(), Float.NaN, "testDivConstantFolding");
@ -489,6 +660,7 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.MUL_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testMulConstantFolding() {
// If any operand is NaN, the result is NaN.
assertResult(multiply(Float16.NaN, valueOf(4.0f)).floatValue(), Float.NaN, "testMulConstantFolding");
@ -514,6 +686,7 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.SQRT_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testSqrtConstantFolding() {
// If the argument is NaN or less than zero, then the result is NaN.
assertResult(sqrt(Float16.NaN).floatValue(), Float.NaN, "testSqrtConstantFolding");
@ -535,6 +708,7 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.FMA_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testFMAConstantFolding() {
// If any argument is NaN, the result is NaN.
assertResult(fma(Float16.NaN, valueOf(2.0f), valueOf(3.0f)).floatValue(), Float.NaN, "testFMAConstantFolding");
@ -572,6 +746,7 @@ public class TestFloat16ScalarOperations {
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(failOn = {IRNode.ADD_HF, IRNode.SUB_HF, IRNode.MUL_HF, IRNode.DIV_HF, IRNode.SQRT_HF, IRNode.FMA_HF},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testRounding1() {
dst[0] = float16ToRawShortBits(add(RANDOM1, RANDOM2));
dst[1] = float16ToRawShortBits(subtract(RANDOM2, RANDOM3));
@ -608,13 +783,13 @@ public class TestFloat16ScalarOperations {
}
@Test
@Warmup(value = 10000)
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ",
IRNode.DIV_HF, " >0 ", IRNode.SQRT_HF, " >0 ", IRNode.FMA_HF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zfh", "true"})
@IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ",
IRNode.DIV_HF, " >0 ", IRNode.SQRT_HF, " >0 ", IRNode.FMA_HF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
@Warmup(10000)
public void testRounding2() {
dst[0] = float16ToRawShortBits(add(RANDOM1_VAR, RANDOM2_VAR));
dst[1] = float16ToRawShortBits(subtract(RANDOM2_VAR, RANDOM3_VAR));

View File

@ -605,6 +605,29 @@ public final class Generators {
fillFloat(generator, MemorySegment.ofArray(a));
}
/**
* Fills the memory segments with shorts obtained by calling next on the generator.
*
* @param generator The generator from which to source the values.
* @param ms Memory segment to be filled with random values.
*/
public void fillShort(Generator<Short> generator, MemorySegment ms) {
var layout = ValueLayout.JAVA_SHORT_UNALIGNED;
for (long i = 0; i < ms.byteSize() / layout.byteSize(); i++) {
ms.setAtIndex(layout, i, generator.next());
}
}
/**
* Fill the array with shorts using the distribution of the generator.
*
* @param a Array to be filled with random values.
*/
public void fill(Generator<Short> generator, short[] a) {
fillShort(generator, MemorySegment.ofArray(a));
}
/**
* Fills the memory segments with ints obtained by calling next on the generator.
*

View File

@ -37,6 +37,7 @@ import compiler.lib.ir_framework.*;
import jdk.incubator.vector.Float16;
import static jdk.incubator.vector.Float16.*;
import static java.lang.Float.*;
import java.util.Arrays;
import jdk.test.lib.*;
import compiler.lib.generators.Generator;
import static compiler.lib.generators.Generators.G;
@ -46,9 +47,11 @@ public class TestFloat16VectorOperations {
private short[] input2;
private short[] input3;
private short[] output;
private static short SCALAR_FP16 = (short)0x7777;
private static short FP16_SCALAR = (short)0x7777;
private static final int LEN = 2048;
private static final Float16 FP16_CONST = Float16.valueOf(1023.0f);
public static void main(String args[]) {
// Test with default MaxVectorSize
TestFramework.runWithFlags("--add-modules=jdk.incubator.vector");
@ -60,10 +63,14 @@ public class TestFloat16VectorOperations {
TestFramework.runWithFlags("--add-modules=jdk.incubator.vector", "-XX:MaxVectorSize=64");
}
public static boolean assertResults(short expected, short actual) {
Float16 expected_fp16 = shortBitsToFloat16(expected);
Float16 actual_fp16 = shortBitsToFloat16(actual);
return !expected_fp16.equals(actual_fp16);
public static void assertResults(int arity, short ... values) {
assert values.length == (arity + 2);
Float16 expected_fp16 = shortBitsToFloat16(values[arity]);
Float16 actual_fp16 = shortBitsToFloat16(values[arity + 1]);
if(!expected_fp16.equals(actual_fp16)) {
String inputs = Arrays.toString(Arrays.copyOfRange(values, 0, arity - 1));
throw new AssertionError("Result Mismatch!, input = " + inputs + " actual = " + actual_fp16 + " expected = " + expected_fp16);
}
}
public TestFloat16VectorOperations() {
@ -84,9 +91,9 @@ public class TestFloat16VectorOperations {
@Test
@Warmup(50)
@IR(counts = {IRNode.ADD_VHF, ">= 1"},
@IR(counts = {IRNode.ADD_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.ADD_VHF, ">= 1"},
@IR(counts = {IRNode.ADD_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorAddFloat16() {
for (int i = 0; i < LEN; ++i) {
@ -98,18 +105,16 @@ public class TestFloat16VectorOperations {
public void checkResultAdd() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(float16ToFloat(input1[i]) + float16ToFloat(input2[i]));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + input2[i] +
" output = " + output[i] + " expected = " + expected);
}
assertResults(2, input1[i], input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.SUB_VHF, ">= 1"},
@IR(counts = {IRNode.SUB_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.SUB_VHF, ">= 1"},
@IR(counts = {IRNode.SUB_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorSubFloat16() {
for (int i = 0; i < LEN; ++i) {
@ -121,18 +126,16 @@ public class TestFloat16VectorOperations {
public void checkResultSub() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(float16ToFloat(input1[i]) - float16ToFloat(input2[i]));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + input2[i] +
" output = " + output[i] + " expected = " + expected);
}
assertResults(2, input1[i], input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.MUL_VHF, ">= 1"},
@IR(counts = {IRNode.MUL_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.MUL_VHF, ">= 1"},
@IR(counts = {IRNode.MUL_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorMulFloat16() {
for (int i = 0; i < LEN; ++i) {
@ -144,18 +147,15 @@ public class TestFloat16VectorOperations {
public void checkResultMul() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(float16ToFloat(input1[i]) * float16ToFloat(input2[i]));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + input2[i] +
" output = " + output[i] + " expected = " + expected);
}
assertResults(2, input1[i], input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.DIV_VHF, ">= 1"},
@IR(counts = {IRNode.DIV_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.DIV_VHF, ">= 1"},
@IR(counts = {IRNode.DIV_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorDivFloat16() {
for (int i = 0; i < LEN; ++i) {
@ -167,18 +167,15 @@ public class TestFloat16VectorOperations {
public void checkResultDiv() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(float16ToFloat(input1[i]) / float16ToFloat(input2[i]));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + input2[i] +
" output = " + output[i] + " expected = " + expected);
}
assertResults(2, input1[i], input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.MIN_VHF, ">= 1"},
@IR(counts = {IRNode.MIN_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.MIN_VHF, ">= 1"},
@IR(counts = {IRNode.MIN_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorMinFloat16() {
for (int i = 0; i < LEN; ++i) {
@ -190,18 +187,15 @@ public class TestFloat16VectorOperations {
public void checkResultMin() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(Math.min(float16ToFloat(input1[i]), float16ToFloat(input2[i])));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + input2[i] +
" output = " + output[i] + " expected = " + expected);
}
assertResults(2, input1[i], input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.MAX_VHF, ">= 1"},
@IR(counts = {IRNode.MAX_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.MAX_VHF, ">= 1"},
@IR(counts = {IRNode.MAX_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorMaxFloat16() {
for (int i = 0; i < LEN; ++i) {
@ -213,18 +207,15 @@ public class TestFloat16VectorOperations {
public void checkResultMax() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(Math.max(float16ToFloat(input1[i]), float16ToFloat(input2[i])));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + input2[i] +
" output = " + output[i] + " expected = " + expected);
}
assertResults(2, input1[i], input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.SQRT_VHF, ">= 1"},
@IR(counts = {IRNode.SQRT_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.SQRT_VHF, ">= 1"},
@IR(counts = {IRNode.SQRT_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorSqrtFloat16() {
for (int i = 0; i < LEN; ++i) {
@ -236,18 +227,15 @@ public class TestFloat16VectorOperations {
public void checkResultSqrt() {
for (int i = 0; i < LEN; ++i) {
short expected = float16ToRawShortBits(sqrt(shortBitsToFloat16(input1[i])));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input = " + input1[i] +
" output = " + output[i] + " expected = " + expected);
}
assertResults(1, input1[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.FMA_VHF, ">= 1"},
@IR(counts = {IRNode.FMA_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.FMA_VHF, ">= 1"},
@IR(counts = {IRNode.FMA_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorFmaFloat16() {
for (int i = 0; i < LEN; ++i) {
@ -261,22 +249,19 @@ public class TestFloat16VectorOperations {
for (int i = 0; i < LEN; ++i) {
short expected = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]),
shortBitsToFloat16(input3[i])));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + input2[i] +
"input3 = " + input3[i] + " output = " + output[i] + " expected = " + expected);
}
assertResults(3, input1[i], input2[i], input3[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.FMA_VHF, " >= 1"},
@IR(counts = {IRNode.FMA_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.FMA_VHF, ">= 1"},
@IR(counts = {IRNode.FMA_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorFmaFloat16ScalarMixedConstants() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(SCALAR_FP16),
output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(FP16_SCALAR),
shortBitsToFloat16(floatToFloat16(3.0f))));
}
}
@ -284,21 +269,18 @@ public class TestFloat16VectorOperations {
@Check(test="vectorFmaFloat16ScalarMixedConstants")
public void checkResultFmaScalarMixedConstants() {
for (int i = 0; i < LEN; ++i) {
short expected = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(SCALAR_FP16),
short expected = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(FP16_SCALAR),
shortBitsToFloat16(floatToFloat16(3.0f))));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + SCALAR_FP16 +
"input3 = 3.0 " + "output = " + output[i] + " expected = " + expected);
}
assertResults(2, input1[i], FP16_SCALAR, expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.FMA_VHF, " >= 1"},
@IR(counts = {IRNode.FMA_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.FMA_VHF, ">= 1"},
@IR(counts = {IRNode.FMA_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorFmaFloat16MixedConstants() {
short input3 = floatToFloat16(3.0f);
@ -307,15 +289,13 @@ public class TestFloat16VectorOperations {
}
}
@Check(test="vectorFmaFloat16MixedConstants")
public void checkResultFmaMixedConstants() {
short input3 = floatToFloat16(3.0f);
for (int i = 0; i < LEN; ++i) {
short expected = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]), shortBitsToFloat16(input3)));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1[i] + " input2 = " + input2[i] +
"input3 = " + input3 + " output = " + output[i] + " expected = " + expected);
}
assertResults(3, input1[i], input2[i], input3, expected, output[i]);
}
}
@ -341,10 +321,118 @@ public class TestFloat16VectorOperations {
short input3 = floatToFloat16(3.0f);
for (int i = 0; i < LEN; ++i) {
short expected = float16ToRawShortBits(fma(shortBitsToFloat16(input1), shortBitsToFloat16(input2), shortBitsToFloat16(input3)));
if (assertResults(expected, output[i])) {
throw new RuntimeException("Invalid result: [" + i + "] input1 = " + input1 + " input2 = " + input2 +
"input3 = " + input3 + " output = " + output[i] + " expected = " + expected);
}
assertResults(3, input1, input2, input3, expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.ADD_VHF, " >0 "},
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true", "sve", "true"})
@IR(counts = {IRNode.ADD_VHF, " >0 "},
applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"})
public void vectorAddConstInputFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(add(shortBitsToFloat16(input1[i]), FP16_CONST));
}
}
@Check(test="vectorAddConstInputFloat16")
public void checkResultAddConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(float16ToFloat(input1[i]) + FP16_CONST.floatValue());
assertResults(2, input1[i], float16ToRawShortBits(FP16_CONST), expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.SUB_VHF, " >0 "},
applyIfCPUFeature = {"avx512_fp16", "true"})
public void vectorSubConstInputFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(subtract(shortBitsToFloat16(input1[i]), FP16_CONST));
}
}
@Check(test="vectorSubConstInputFloat16")
public void checkResultSubConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(float16ToFloat(input1[i]) - FP16_CONST.floatValue());
assertResults(2, input1[i], float16ToRawShortBits(FP16_CONST), expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.MUL_VHF, " >0 "},
applyIfCPUFeature = {"avx512_fp16", "true"})
public void vectorMulConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(multiply(FP16_CONST, shortBitsToFloat16(input2[i])));
}
}
@Check(test="vectorMulConstantInputFloat16")
public void checkResultMulConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(FP16_CONST.floatValue() * float16ToFloat(input2[i]));
assertResults(2, float16ToRawShortBits(FP16_CONST), input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.DIV_VHF, " >0 "},
applyIfCPUFeature = {"avx512_fp16", "true"})
public void vectorDivConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(divide(FP16_CONST, shortBitsToFloat16(input2[i])));
}
}
@Check(test="vectorDivConstantInputFloat16")
public void checkResultDivConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(FP16_CONST.floatValue() / float16ToFloat(input2[i]));
assertResults(2, float16ToRawShortBits(FP16_CONST), input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.MAX_VHF, " >0 "},
applyIfCPUFeature = {"avx512_fp16", "true"})
public void vectorMaxConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(max(FP16_CONST, shortBitsToFloat16(input2[i])));
}
}
@Check(test="vectorMaxConstantInputFloat16")
public void checkResultMaxConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(Math.max(FP16_CONST.floatValue(), float16ToFloat(input2[i])));
assertResults(2, float16ToRawShortBits(FP16_CONST), input2[i], expected, output[i]);
}
}
@Test
@Warmup(50)
@IR(counts = {IRNode.MIN_VHF, " >0 "},
applyIfCPUFeature = {"avx512_fp16", "true"})
public void vectorMinConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(min(FP16_CONST, shortBitsToFloat16(input2[i])));
}
}
@Check(test="vectorMinConstantInputFloat16")
public void checkResultMinConstantInputFloat16() {
for (int i = 0; i < LEN; ++i) {
short expected = floatToFloat16(Math.min(FP16_CONST.floatValue(), float16ToFloat(input2[i])));
assertResults(2, float16ToRawShortBits(FP16_CONST), input2[i], expected, output[i]);
}
}
}