8352585: Add special case handling for Float16.max/min x86 backend

Reviewed-by: epeter, sviswanathan
This commit is contained in:
Jatin Bhateja 2025-04-03 09:21:55 +00:00
parent 9c5ed23eac
commit f7a94feedd
6 changed files with 254 additions and 6 deletions

View File

@ -13810,6 +13810,16 @@ void Assembler::vcmpps(XMMRegister dst, XMMRegister nds, XMMRegister src, int co
emit_int24((unsigned char)0xC2, (0xC0 | encode), (unsigned char)comparison);
}
void Assembler::evcmpsh(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src, ComparisonPredicateFP comparison) {
assert(VM_Version::supports_avx512_fp16(), "");
InstructionAttr attributes(Assembler::AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true);
attributes.set_is_evex_instruction();
attributes.set_embedded_opmask_register_specifier(mask);
attributes.reset_is_clear_context();
int encode = vex_prefix_and_encode(kdst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F_3A, &attributes);
emit_int24((unsigned char)0xC2, (0xC0 | encode), comparison);
}
void Assembler::evcmpps(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
ComparisonPredicateFP comparison, int vector_len) {
assert(VM_Version::supports_evex(), "");

View File

@ -3195,6 +3195,9 @@ private:
void evcmpps(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
ComparisonPredicateFP comparison, int vector_len);
void evcmpsh(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
ComparisonPredicateFP comparison);
// Vector integer compares
void vpcmpgtd(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len);
void evpcmpd(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,

View File

@ -6680,8 +6680,6 @@ void C2_MacroAssembler::efp16sh(int opcode, XMMRegister dst, XMMRegister src1, X
case Op_SubHF: vsubsh(dst, src1, src2); break;
case Op_MulHF: vmulsh(dst, src1, src2); break;
case Op_DivHF: vdivsh(dst, src1, src2); break;
case Op_MaxHF: vmaxsh(dst, src1, src2); break;
case Op_MinHF: vminsh(dst, src1, src2); break;
default: assert(false, "%s", NodeClassNames[opcode]); break;
}
}
@ -7091,3 +7089,48 @@ void C2_MacroAssembler::vector_saturating_op(int ideal_opc, BasicType elem_bt, X
vector_saturating_op(ideal_opc, elem_bt, dst, src1, src2, vlen_enc);
}
}
void C2_MacroAssembler::scalar_max_min_fp16(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2,
KRegister ktmp, XMMRegister xtmp1, XMMRegister xtmp2, int vlen_enc) {
if (opcode == Op_MaxHF) {
// Move sign bits of src2 to mask register.
evpmovw2m(ktmp, src2, vlen_enc);
// xtmp1 = src2 < 0 ? src2 : src1
evpblendmw(xtmp1, ktmp, src1, src2, true, vlen_enc);
// xtmp2 = src2 < 0 ? ? src1 : src2
evpblendmw(xtmp2, ktmp, src2, src1, true, vlen_enc);
// Idea behind above swapping is to make seconds source operand a +ve value.
// As per instruction semantic, if the values being compared are both 0.0s (of either sign), the value in
// the second source operand is returned. If only one value is a NaN (SNaN or QNaN) for this instruction,
// the second source operand, either a NaN or a valid floating-point value, is returned
// dst = max(xtmp1, xtmp2)
vmaxsh(dst, xtmp1, xtmp2);
// isNaN = is_unordered_quiet(xtmp1)
evcmpsh(ktmp, k0, xtmp1, xtmp1, Assembler::UNORD_Q);
// Final result is same as first source if its a NaN value,
// in case second operand holds a NaN value then as per above semantics
// result is same as second operand.
Assembler::evmovdquw(dst, ktmp, xtmp1, true, vlen_enc);
} else {
assert(opcode == Op_MinHF, "");
// Move sign bits of src1 to mask register.
evpmovw2m(ktmp, src1, vlen_enc);
// xtmp1 = src1 < 0 ? src2 : src1
evpblendmw(xtmp1, ktmp, src1, src2, true, vlen_enc);
// xtmp2 = src1 < 0 ? src1 : src2
evpblendmw(xtmp2, ktmp, src2, src1, true, vlen_enc);
// Idea behind above swapping is to make seconds source operand a -ve value.
// As per instruction semantics, if the values being compared are both 0.0s (of either sign), the value in
// the second source operand is returned.
// If only one value is a NaN (SNaN or QNaN) for this instruction, the second source operand, either a NaN
// or a valid floating-point value, is written to the result.
// dst = min(xtmp1, xtmp2)
vminsh(dst, xtmp1, xtmp2);
// isNaN = is_unordered_quiet(xtmp1)
evcmpsh(ktmp, k0, xtmp1, xtmp1, Assembler::UNORD_Q);
// Final result is same as first source if its a NaN value,
// in case second operand holds a NaN value then as per above semantics
// result is same as second operand.
Assembler::evmovdquw(dst, ktmp, xtmp1, true, vlen_enc);
}
}

View File

@ -584,4 +584,6 @@ public:
void select_from_two_vectors_evex(BasicType elem_bt, XMMRegister dst, XMMRegister src1, XMMRegister src2, int vlen_enc);
void scalar_max_min_fp16(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2,
KRegister ktmp, XMMRegister xtmp1, XMMRegister xtmp2, int vlen_enc);
#endif // CPU_X86_C2_MACROASSEMBLER_X86_HPP

View File

@ -1461,11 +1461,14 @@ bool Matcher::match_rule_supported(int opcode) {
return false;
}
break;
case Op_MaxHF:
case Op_MinHF:
if (!VM_Version::supports_avx512vlbw()) {
return false;
} // fallthrough
case Op_AddHF:
case Op_DivHF:
case Op_FmaHF:
case Op_MaxHF:
case Op_MinHF:
case Op_MulHF:
case Op_ReinterpretS2HF:
case Op_ReinterpretHF2S:
@ -10935,8 +10938,6 @@ instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2)
%{
match(Set dst (AddHF src1 src2));
match(Set dst (DivHF src1 src2));
match(Set dst (MaxHF src1 src2));
match(Set dst (MinHF src1 src2));
match(Set dst (MulHF src1 src2));
match(Set dst (SubHF src1 src2));
format %{ "scalar_binop_fp16 $dst, $src1, $src2" %}
@ -10947,6 +10948,20 @@ instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2)
ins_pipe(pipe_slow);
%}
instruct scalar_minmax_HF_reg(regF dst, regF src1, regF src2, kReg ktmp, regF xtmp1, regF xtmp2)
%{
match(Set dst (MaxHF src1 src2));
match(Set dst (MinHF src1 src2));
effect(TEMP_DEF dst, TEMP ktmp, TEMP xtmp1, TEMP xtmp2);
format %{ "scalar_min_max_fp16 $dst, $src1, $src2\t using $ktmp, $xtmp1 and $xtmp2 as TEMP" %}
ins_encode %{
int opcode = this->ideal_Opcode();
__ scalar_max_min_fp16(opcode, $dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister, $ktmp$$KRegister,
$xtmp1$$XMMRegister, $xtmp2$$XMMRegister, Assembler::AVX_128bit);
%}
ins_pipe( pipe_slow );
%}
instruct scalar_fma_HF_reg(regF dst, regF src1, regF src2)
%{
match(Set dst (FmaHF src2 (Binary dst src1)));

View File

@ -0,0 +1,175 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package compiler.intrinsics.float16;
import compiler.lib.ir_framework.*;
import jdk.incubator.vector.*;
import java.util.Random;
import jdk.test.lib.*;
/**
* @test
* @bug 8352585
* @library /test/lib /
* @summary Add special case handling for Float16.max/min x86 backend
* @modules jdk.incubator.vector
* @run driver compiler.intrinsics.float16.TestFloat16MaxMinSpecialValues
*/
public class TestFloat16MaxMinSpecialValues {
public static Float16 POS_ZERO = Float16.valueOf(0.0f);
public static Float16 NEG_ZERO = Float16.valueOf(-0.0f);
public static Float16 SRC = Float16.valueOf(Float.MAX_VALUE);
public static Random rd = Utils.getRandomInstance();
public static Float16 genNaN() {
// IEEE 754 Half Precision QNaN Format
// S EEEEE MMMMMMMMMM
// X 11111 1XXXXXXXXX
short sign = (short)(rd.nextBoolean() ? 1 << 15 : 0);
short significand = (short)rd.nextInt(512);
return Float16.shortBitsToFloat16((short)(sign | 0x7E00 | significand));
}
public static boolean assertionCheck(Float16 actual, Float16 expected) {
return !actual.equals(expected);
}
public Float16 RES;
public static void main(String [] args) {
TestFramework.runWithFlags("--add-modules=jdk.incubator.vector");
}
@Test
@IR(counts = {IRNode.MAX_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
public Float16 testMaxNaNOperands(Float16 src1, Float16 src2) {
return Float16.max(src1, src2);
}
@Run(test = "testMaxNaNOperands")
public void launchMaxNaNOperands() {
Float16 NAN = null;
for (int i = 0; i < 100; i++) {
NAN = genNaN();
RES = testMaxNaNOperands(SRC, NAN);
if (assertionCheck(RES, NAN)) {
throw new AssertionError("input1 = " + SRC.floatValue() + " input2 = NaN , expected = NaN, actual = " + RES.floatValue());
}
NAN = genNaN();
RES = testMaxNaNOperands(NAN, SRC);
if (assertionCheck(RES, NAN)) {
throw new AssertionError("input1 = NaN, input2 = " + SRC.floatValue() + ", expected = NaN, actual = " + RES.floatValue());
}
NAN = genNaN();
RES = testMaxNaNOperands(NAN, NAN);
if (assertionCheck(RES, NAN)) {
throw new AssertionError("input1 = NaN, input2 = NaN, expected = NaN, actual = " + RES.floatValue());
}
}
}
@Test
@IR(counts = {IRNode.MIN_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
public Float16 testMinNaNOperands(Float16 src1, Float16 src2) {
return Float16.min(src1, src2);
}
@Run(test = "testMinNaNOperands")
public void launchMinNaNOperands() {
Float16 NAN = null;
for (int i = 0; i < 100; i++) {
NAN = genNaN();
RES = testMinNaNOperands(SRC, NAN);
if (assertionCheck(RES, NAN)) {
throw new AssertionError("input1 = " + SRC.floatValue() + " input2 = NaN, expected = NaN, actual = " + RES.floatValue());
}
NAN = genNaN();
RES = testMinNaNOperands(NAN, SRC);
if (assertionCheck(RES, NAN)) {
throw new AssertionError("input1 = NaN, input2 = " + SRC.floatValue() + ", expected = NaN, actual = " + RES.floatValue());
}
NAN = genNaN();
RES = testMinNaNOperands(NAN, NAN);
if (assertionCheck(RES, NAN)) {
throw new AssertionError("input1 = NaN, input2 = NaN, expected = NaN, actual = " + RES.floatValue());
}
}
}
@Test
@IR(counts = {IRNode.MAX_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
public Float16 testMaxZeroOperands(Float16 src1, Float16 src2) {
return Float16.max(src1, src2);
}
@Run(test = "testMaxZeroOperands")
public void launchMaxZeroOperands() {
RES = testMaxZeroOperands(POS_ZERO, NEG_ZERO);
if (assertionCheck(RES, POS_ZERO)) {
throw new AssertionError("input1 = +0.0, input2 = -0.0, expected = +0.0, actual = " + RES.floatValue());
}
RES = testMaxZeroOperands(NEG_ZERO, POS_ZERO);
if (assertionCheck(RES, POS_ZERO)) {
throw new AssertionError("input1 = -0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
}
RES = testMaxZeroOperands(POS_ZERO, POS_ZERO);
if (assertionCheck(RES, POS_ZERO)) {
throw new AssertionError("input1 = +0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
}
RES = testMaxZeroOperands(NEG_ZERO, NEG_ZERO);
if (assertionCheck(RES, NEG_ZERO)) {
throw new AssertionError("input1 = -0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
}
}
@Test
@IR(counts = {IRNode.MIN_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
public Float16 testMinZeroOperands(Float16 src1, Float16 src2) {
return Float16.min(src1, src2);
}
@Run(test = "testMinZeroOperands")
public void launchMinZeroOperands() {
RES = testMinZeroOperands(POS_ZERO, NEG_ZERO);
if (assertionCheck(RES, NEG_ZERO)) {
throw new AssertionError("input1 = +0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
}
RES = testMinZeroOperands(NEG_ZERO, POS_ZERO);
if (assertionCheck(RES, NEG_ZERO)) {
throw new AssertionError("input1 = -0.0, input2 = +0.0, expected = -0.0, actual = " + RES.floatValue());
}
RES = testMinZeroOperands(POS_ZERO, POS_ZERO);
if (assertionCheck(RES, POS_ZERO)) {
throw new AssertionError("input1 = +0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
}
RES = testMinZeroOperands(NEG_ZERO, NEG_ZERO);
if (assertionCheck(RES, NEG_ZERO)) {
throw new AssertionError("input1 = -0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
}
}
}