mirror of
https://github.com/openjdk/jdk.git
synced 2026-03-14 18:03:44 +00:00
8352585: Add special case handling for Float16.max/min x86 backend
Reviewed-by: epeter, sviswanathan
This commit is contained in:
parent
9c5ed23eac
commit
f7a94feedd
@ -13810,6 +13810,16 @@ void Assembler::vcmpps(XMMRegister dst, XMMRegister nds, XMMRegister src, int co
|
||||
emit_int24((unsigned char)0xC2, (0xC0 | encode), (unsigned char)comparison);
|
||||
}
|
||||
|
||||
// Emits an EVEX-encoded compare (F3 prefix, 0F 3A opcode map, opcode byte 0xC2)
// of nds against src, using `comparison` as the trailing immediate byte.
//   kdst        - opmask register encoded as the destination.
//   mask        - opmask register embedded in the EVEX prefix.
//   nds, src    - the two XMM source operands (register-register form only).
//   comparison  - floating-point predicate, emitted as the imm8.
// Requires AVX512-FP16 support (asserted).
// NOTE(review): presumably this is the scalar half-precision VCMPSH form - confirm
// against the instruction reference.
void Assembler::evcmpsh(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src, ComparisonPredicateFP comparison) {
  assert(VM_Version::supports_avx512_fp16(), "");

  InstructionAttr attrs(Assembler::AVX_128bit,
                        /* vex_w */ false,
                        /* legacy_mode */ false,
                        /* no_mask_reg */ false,
                        /* uses_vl */ true);
  attrs.set_is_evex_instruction();
  attrs.set_embedded_opmask_register_specifier(mask);
  attrs.reset_is_clear_context();

  // Emit the prefix; the returned value supplies the register fields of the ModRM byte.
  const int reg_bits = vex_prefix_and_encode(kdst->encoding(), nds->encoding(), src->encoding(),
                                             VEX_SIMD_F3, VEX_OPCODE_0F_3A, &attrs);

  // opcode, ModRM (0xC0 selects the register-direct form), predicate immediate.
  emit_int24((unsigned char)0xC2, (0xC0 | reg_bits), comparison);
}
|
||||
|
||||
void Assembler::evcmpps(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
|
||||
ComparisonPredicateFP comparison, int vector_len) {
|
||||
assert(VM_Version::supports_evex(), "");
|
||||
|
||||
@ -3195,6 +3195,9 @@ private:
|
||||
void evcmpps(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
|
||||
ComparisonPredicateFP comparison, int vector_len);
|
||||
|
||||
void evcmpsh(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
|
||||
ComparisonPredicateFP comparison);
|
||||
|
||||
// Vector integer compares
|
||||
void vpcmpgtd(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len);
|
||||
void evpcmpd(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src,
|
||||
|
||||
@ -6680,8 +6680,6 @@ void C2_MacroAssembler::efp16sh(int opcode, XMMRegister dst, XMMRegister src1, X
|
||||
case Op_SubHF: vsubsh(dst, src1, src2); break;
|
||||
case Op_MulHF: vmulsh(dst, src1, src2); break;
|
||||
case Op_DivHF: vdivsh(dst, src1, src2); break;
|
||||
case Op_MaxHF: vmaxsh(dst, src1, src2); break;
|
||||
case Op_MinHF: vminsh(dst, src1, src2); break;
|
||||
default: assert(false, "%s", NodeClassNames[opcode]); break;
|
||||
}
|
||||
}
|
||||
@ -7091,3 +7089,48 @@ void C2_MacroAssembler::vector_saturating_op(int ideal_opc, BasicType elem_bt, X
|
||||
vector_saturating_op(ideal_opc, elem_bt, dst, src1, src2, vlen_enc);
|
||||
}
|
||||
}
|
||||
|
||||
void C2_MacroAssembler::scalar_max_min_fp16(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2,
|
||||
KRegister ktmp, XMMRegister xtmp1, XMMRegister xtmp2, int vlen_enc) {
|
||||
if (opcode == Op_MaxHF) {
|
||||
// Move sign bits of src2 to mask register.
|
||||
evpmovw2m(ktmp, src2, vlen_enc);
|
||||
// xtmp1 = src2 < 0 ? src2 : src1
|
||||
evpblendmw(xtmp1, ktmp, src1, src2, true, vlen_enc);
|
||||
// xtmp2 = src2 < 0 ? ? src1 : src2
|
||||
evpblendmw(xtmp2, ktmp, src2, src1, true, vlen_enc);
|
||||
// Idea behind above swapping is to make seconds source operand a +ve value.
|
||||
// As per instruction semantic, if the values being compared are both 0.0s (of either sign), the value in
|
||||
// the second source operand is returned. If only one value is a NaN (SNaN or QNaN) for this instruction,
|
||||
// the second source operand, either a NaN or a valid floating-point value, is returned
|
||||
// dst = max(xtmp1, xtmp2)
|
||||
vmaxsh(dst, xtmp1, xtmp2);
|
||||
// isNaN = is_unordered_quiet(xtmp1)
|
||||
evcmpsh(ktmp, k0, xtmp1, xtmp1, Assembler::UNORD_Q);
|
||||
// Final result is same as first source if its a NaN value,
|
||||
// in case second operand holds a NaN value then as per above semantics
|
||||
// result is same as second operand.
|
||||
Assembler::evmovdquw(dst, ktmp, xtmp1, true, vlen_enc);
|
||||
} else {
|
||||
assert(opcode == Op_MinHF, "");
|
||||
// Move sign bits of src1 to mask register.
|
||||
evpmovw2m(ktmp, src1, vlen_enc);
|
||||
// xtmp1 = src1 < 0 ? src2 : src1
|
||||
evpblendmw(xtmp1, ktmp, src1, src2, true, vlen_enc);
|
||||
// xtmp2 = src1 < 0 ? src1 : src2
|
||||
evpblendmw(xtmp2, ktmp, src2, src1, true, vlen_enc);
|
||||
// Idea behind above swapping is to make seconds source operand a -ve value.
|
||||
// As per instruction semantics, if the values being compared are both 0.0s (of either sign), the value in
|
||||
// the second source operand is returned.
|
||||
// If only one value is a NaN (SNaN or QNaN) for this instruction, the second source operand, either a NaN
|
||||
// or a valid floating-point value, is written to the result.
|
||||
// dst = min(xtmp1, xtmp2)
|
||||
vminsh(dst, xtmp1, xtmp2);
|
||||
// isNaN = is_unordered_quiet(xtmp1)
|
||||
evcmpsh(ktmp, k0, xtmp1, xtmp1, Assembler::UNORD_Q);
|
||||
// Final result is same as first source if its a NaN value,
|
||||
// in case second operand holds a NaN value then as per above semantics
|
||||
// result is same as second operand.
|
||||
Assembler::evmovdquw(dst, ktmp, xtmp1, true, vlen_enc);
|
||||
}
|
||||
}
|
||||
|
||||
@ -584,4 +584,6 @@ public:
|
||||
|
||||
void select_from_two_vectors_evex(BasicType elem_bt, XMMRegister dst, XMMRegister src1, XMMRegister src2, int vlen_enc);
|
||||
|
||||
void scalar_max_min_fp16(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2,
|
||||
KRegister ktmp, XMMRegister xtmp1, XMMRegister xtmp2, int vlen_enc);
|
||||
#endif // CPU_X86_C2_MACROASSEMBLER_X86_HPP
|
||||
|
||||
@ -1461,11 +1461,14 @@ bool Matcher::match_rule_supported(int opcode) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case Op_MaxHF:
|
||||
case Op_MinHF:
|
||||
if (!VM_Version::supports_avx512vlbw()) {
|
||||
return false;
|
||||
} // fallthrough
|
||||
case Op_AddHF:
|
||||
case Op_DivHF:
|
||||
case Op_FmaHF:
|
||||
case Op_MaxHF:
|
||||
case Op_MinHF:
|
||||
case Op_MulHF:
|
||||
case Op_ReinterpretS2HF:
|
||||
case Op_ReinterpretHF2S:
|
||||
@ -10935,8 +10938,6 @@ instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2)
|
||||
%{
|
||||
match(Set dst (AddHF src1 src2));
|
||||
match(Set dst (DivHF src1 src2));
|
||||
match(Set dst (MaxHF src1 src2));
|
||||
match(Set dst (MinHF src1 src2));
|
||||
match(Set dst (MulHF src1 src2));
|
||||
match(Set dst (SubHF src1 src2));
|
||||
format %{ "scalar_binop_fp16 $dst, $src1, $src2" %}
|
||||
@ -10947,6 +10948,20 @@ instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2)
|
||||
ins_pipe(pipe_slow);
|
||||
%}
|
||||
|
||||
// Matches scalar MaxHF/MinHF and expands them through
// C2_MacroAssembler::scalar_max_min_fp16, which needs one opmask temporary and
// two XMM temporaries. dst is declared TEMP_DEF alongside the TEMP operands.
instruct scalar_minmax_HF_reg(regF dst, regF src1, regF src2, kReg ktmp, regF xtmp1, regF xtmp2)
%{
  match(Set dst (MaxHF src1 src2));
  match(Set dst (MinHF src1 src2));
  effect(TEMP_DEF dst, TEMP ktmp, TEMP xtmp1, TEMP xtmp2);
  format %{ "scalar_min_max_fp16 $dst, $src1, $src2\t using $ktmp, $xtmp1 and $xtmp2 as TEMP" %}
  ins_encode %{
    // The ideal opcode tells the macro assembler whether to emit max or min.
    int ideal_opc = this->ideal_Opcode();
    __ scalar_max_min_fp16(ideal_opc,
                           $dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister,
                           $ktmp$$KRegister, $xtmp1$$XMMRegister, $xtmp2$$XMMRegister,
                           Assembler::AVX_128bit);
  %}
  ins_pipe(pipe_slow);
%}
|
||||
|
||||
instruct scalar_fma_HF_reg(regF dst, regF src1, regF src2)
|
||||
%{
|
||||
match(Set dst (FmaHF src2 (Binary dst src1)));
|
||||
|
||||
@ -0,0 +1,175 @@
|
||||
/*
|
||||
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*/
|
||||
package compiler.intrinsics.float16;
|
||||
|
||||
import compiler.lib.ir_framework.*;
|
||||
import jdk.incubator.vector.*;
|
||||
import java.util.Random;
|
||||
import jdk.test.lib.*;
|
||||
|
||||
/**
|
||||
* @test
|
||||
* @bug 8352585
|
||||
* @library /test/lib /
|
||||
* @summary Add special case handling for Float16.max/min x86 backend
|
||||
* @modules jdk.incubator.vector
|
||||
* @run driver compiler.intrinsics.float16.TestFloat16MaxMinSpecialValues
|
||||
*/
|
||||
|
||||
|
||||
public class TestFloat16MaxMinSpecialValues {
|
||||
public static Float16 POS_ZERO = Float16.valueOf(0.0f);
|
||||
public static Float16 NEG_ZERO = Float16.valueOf(-0.0f);
|
||||
public static Float16 SRC = Float16.valueOf(Float.MAX_VALUE);
|
||||
public static Random rd = Utils.getRandomInstance();
|
||||
|
||||
public static Float16 genNaN() {
|
||||
// IEEE 754 Half Precision QNaN Format
|
||||
// S EEEEE MMMMMMMMMM
|
||||
// X 11111 1XXXXXXXXX
|
||||
short sign = (short)(rd.nextBoolean() ? 1 << 15 : 0);
|
||||
short significand = (short)rd.nextInt(512);
|
||||
return Float16.shortBitsToFloat16((short)(sign | 0x7E00 | significand));
|
||||
}
|
||||
|
||||
public static boolean assertionCheck(Float16 actual, Float16 expected) {
|
||||
return !actual.equals(expected);
|
||||
}
|
||||
|
||||
public Float16 RES;
|
||||
|
||||
public static void main(String [] args) {
|
||||
TestFramework.runWithFlags("--add-modules=jdk.incubator.vector");
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = {IRNode.MAX_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
|
||||
public Float16 testMaxNaNOperands(Float16 src1, Float16 src2) {
|
||||
return Float16.max(src1, src2);
|
||||
}
|
||||
|
||||
@Run(test = "testMaxNaNOperands")
|
||||
public void launchMaxNaNOperands() {
|
||||
Float16 NAN = null;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
NAN = genNaN();
|
||||
RES = testMaxNaNOperands(SRC, NAN);
|
||||
if (assertionCheck(RES, NAN)) {
|
||||
throw new AssertionError("input1 = " + SRC.floatValue() + " input2 = NaN , expected = NaN, actual = " + RES.floatValue());
|
||||
}
|
||||
NAN = genNaN();
|
||||
RES = testMaxNaNOperands(NAN, SRC);
|
||||
if (assertionCheck(RES, NAN)) {
|
||||
throw new AssertionError("input1 = NaN, input2 = " + SRC.floatValue() + ", expected = NaN, actual = " + RES.floatValue());
|
||||
}
|
||||
NAN = genNaN();
|
||||
RES = testMaxNaNOperands(NAN, NAN);
|
||||
if (assertionCheck(RES, NAN)) {
|
||||
throw new AssertionError("input1 = NaN, input2 = NaN, expected = NaN, actual = " + RES.floatValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = {IRNode.MIN_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
|
||||
public Float16 testMinNaNOperands(Float16 src1, Float16 src2) {
|
||||
return Float16.min(src1, src2);
|
||||
}
|
||||
|
||||
@Run(test = "testMinNaNOperands")
|
||||
public void launchMinNaNOperands() {
|
||||
Float16 NAN = null;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
NAN = genNaN();
|
||||
RES = testMinNaNOperands(SRC, NAN);
|
||||
if (assertionCheck(RES, NAN)) {
|
||||
throw new AssertionError("input1 = " + SRC.floatValue() + " input2 = NaN, expected = NaN, actual = " + RES.floatValue());
|
||||
}
|
||||
NAN = genNaN();
|
||||
RES = testMinNaNOperands(NAN, SRC);
|
||||
if (assertionCheck(RES, NAN)) {
|
||||
throw new AssertionError("input1 = NaN, input2 = " + SRC.floatValue() + ", expected = NaN, actual = " + RES.floatValue());
|
||||
}
|
||||
NAN = genNaN();
|
||||
RES = testMinNaNOperands(NAN, NAN);
|
||||
if (assertionCheck(RES, NAN)) {
|
||||
throw new AssertionError("input1 = NaN, input2 = NaN, expected = NaN, actual = " + RES.floatValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = {IRNode.MAX_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
|
||||
public Float16 testMaxZeroOperands(Float16 src1, Float16 src2) {
|
||||
return Float16.max(src1, src2);
|
||||
}
|
||||
|
||||
@Run(test = "testMaxZeroOperands")
|
||||
public void launchMaxZeroOperands() {
|
||||
RES = testMaxZeroOperands(POS_ZERO, NEG_ZERO);
|
||||
if (assertionCheck(RES, POS_ZERO)) {
|
||||
throw new AssertionError("input1 = +0.0, input2 = -0.0, expected = +0.0, actual = " + RES.floatValue());
|
||||
}
|
||||
RES = testMaxZeroOperands(NEG_ZERO, POS_ZERO);
|
||||
if (assertionCheck(RES, POS_ZERO)) {
|
||||
throw new AssertionError("input1 = -0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
|
||||
}
|
||||
RES = testMaxZeroOperands(POS_ZERO, POS_ZERO);
|
||||
if (assertionCheck(RES, POS_ZERO)) {
|
||||
throw new AssertionError("input1 = +0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
|
||||
}
|
||||
RES = testMaxZeroOperands(NEG_ZERO, NEG_ZERO);
|
||||
if (assertionCheck(RES, NEG_ZERO)) {
|
||||
throw new AssertionError("input1 = -0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@IR(counts = {IRNode.MIN_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"})
|
||||
public Float16 testMinZeroOperands(Float16 src1, Float16 src2) {
|
||||
return Float16.min(src1, src2);
|
||||
}
|
||||
|
||||
@Run(test = "testMinZeroOperands")
|
||||
public void launchMinZeroOperands() {
|
||||
RES = testMinZeroOperands(POS_ZERO, NEG_ZERO);
|
||||
if (assertionCheck(RES, NEG_ZERO)) {
|
||||
throw new AssertionError("input1 = +0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
|
||||
}
|
||||
|
||||
RES = testMinZeroOperands(NEG_ZERO, POS_ZERO);
|
||||
if (assertionCheck(RES, NEG_ZERO)) {
|
||||
throw new AssertionError("input1 = -0.0, input2 = +0.0, expected = -0.0, actual = " + RES.floatValue());
|
||||
}
|
||||
|
||||
RES = testMinZeroOperands(POS_ZERO, POS_ZERO);
|
||||
if (assertionCheck(RES, POS_ZERO)) {
|
||||
throw new AssertionError("input1 = +0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue());
|
||||
}
|
||||
|
||||
RES = testMinZeroOperands(NEG_ZERO, NEG_ZERO);
|
||||
if (assertionCheck(RES, NEG_ZERO)) {
|
||||
throw new AssertionError("input1 = -0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user