From f7a94feedd63775a09d0bcb9ef3313972e2a5d69 Mon Sep 17 00:00:00 2001 From: Jatin Bhateja Date: Thu, 3 Apr 2025 09:21:55 +0000 Subject: [PATCH] 8352585: Add special case handling for Float16.max/min x86 backend Reviewed-by: epeter, sviswanathan --- src/hotspot/cpu/x86/assembler_x86.cpp | 10 + src/hotspot/cpu/x86/assembler_x86.hpp | 3 + src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp | 47 ++++- src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp | 2 + src/hotspot/cpu/x86/x86.ad | 23 ++- .../TestFloat16MaxMinSpecialValues.java | 175 ++++++++++++++++++ 6 files changed, 254 insertions(+), 6 deletions(-) create mode 100644 test/hotspot/jtreg/compiler/intrinsics/float16/TestFloat16MaxMinSpecialValues.java diff --git a/src/hotspot/cpu/x86/assembler_x86.cpp b/src/hotspot/cpu/x86/assembler_x86.cpp index 48e482e71ae..29e4fcee2f6 100644 --- a/src/hotspot/cpu/x86/assembler_x86.cpp +++ b/src/hotspot/cpu/x86/assembler_x86.cpp @@ -13810,6 +13810,16 @@ void Assembler::vcmpps(XMMRegister dst, XMMRegister nds, XMMRegister src, int co emit_int24((unsigned char)0xC2, (0xC0 | encode), (unsigned char)comparison); } +void Assembler::evcmpsh(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src, ComparisonPredicateFP comparison) { + assert(VM_Version::supports_avx512_fp16(), ""); + InstructionAttr attributes(Assembler::AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_embedded_opmask_register_specifier(mask); + attributes.reset_is_clear_context(); + int encode = vex_prefix_and_encode(kdst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_0F_3A, &attributes); + emit_int24((unsigned char)0xC2, (0xC0 | encode), comparison); +} + void Assembler::evcmpps(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src, ComparisonPredicateFP comparison, int vector_len) { assert(VM_Version::supports_evex(), ""); diff --git a/src/hotspot/cpu/x86/assembler_x86.hpp b/src/hotspot/cpu/x86/assembler_x86.hpp index 5e29961f49a..15ea45edb91 100644 --- a/src/hotspot/cpu/x86/assembler_x86.hpp +++ b/src/hotspot/cpu/x86/assembler_x86.hpp @@ -3195,6 +3195,9 @@ private: void evcmpps(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src, ComparisonPredicateFP comparison, int vector_len); + void evcmpsh(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src, + ComparisonPredicateFP comparison); + // Vector integer compares void vpcmpgtd(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); void evpcmpd(KRegister kdst, KRegister mask, XMMRegister nds, XMMRegister src, diff --git a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp index 8cf721f5b20..b6d513f50f2 100644 --- a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp +++ b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp @@ -6680,8 +6680,6 @@ void C2_MacroAssembler::efp16sh(int opcode, XMMRegister dst, XMMRegister src1, X case Op_SubHF: vsubsh(dst, src1, src2); break; case Op_MulHF: vmulsh(dst, src1, src2); break; case Op_DivHF: vdivsh(dst, src1, src2); break; - case Op_MaxHF: vmaxsh(dst, src1, src2); break; - case Op_MinHF: vminsh(dst, src1, src2); break; default: assert(false, "%s", NodeClassNames[opcode]); break; } } @@ -7091,3 +7089,48 @@ void C2_MacroAssembler::vector_saturating_op(int ideal_opc, BasicType elem_bt, X vector_saturating_op(ideal_opc, elem_bt, dst, src1, src2, vlen_enc); } } + +void C2_MacroAssembler::scalar_max_min_fp16(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2, + KRegister ktmp, XMMRegister xtmp1, XMMRegister xtmp2, int vlen_enc) { + if (opcode == Op_MaxHF) { + // Move sign bits of src2 to mask register. + evpmovw2m(ktmp, src2, vlen_enc); + // xtmp1 = src2 < 0 ? src2 : src1 + evpblendmw(xtmp1, ktmp, src1, src2, true, vlen_enc); + // xtmp2 = src2 < 0 ? ? src1 : src2 + evpblendmw(xtmp2, ktmp, src2, src1, true, vlen_enc); + // Idea behind above swapping is to make seconds source operand a +ve value. + // As per instruction semantic, if the values being compared are both 0.0s (of either sign), the value in + // the second source operand is returned. If only one value is a NaN (SNaN or QNaN) for this instruction, + // the second source operand, either a NaN or a valid floating-point value, is returned + // dst = max(xtmp1, xtmp2) + vmaxsh(dst, xtmp1, xtmp2); + // isNaN = is_unordered_quiet(xtmp1) + evcmpsh(ktmp, k0, xtmp1, xtmp1, Assembler::UNORD_Q); + // Final result is same as first source if its a NaN value, + // in case second operand holds a NaN value then as per above semantics + // result is same as second operand. + Assembler::evmovdquw(dst, ktmp, xtmp1, true, vlen_enc); + } else { + assert(opcode == Op_MinHF, ""); + // Move sign bits of src1 to mask register. + evpmovw2m(ktmp, src1, vlen_enc); + // xtmp1 = src1 < 0 ? src2 : src1 + evpblendmw(xtmp1, ktmp, src1, src2, true, vlen_enc); + // xtmp2 = src1 < 0 ? src1 : src2 + evpblendmw(xtmp2, ktmp, src2, src1, true, vlen_enc); + // Idea behind above swapping is to make seconds source operand a -ve value. + // As per instruction semantics, if the values being compared are both 0.0s (of either sign), the value in + // the second source operand is returned. + // If only one value is a NaN (SNaN or QNaN) for this instruction, the second source operand, either a NaN + // or a valid floating-point value, is written to the result. + // dst = min(xtmp1, xtmp2) + vminsh(dst, xtmp1, xtmp2); + // isNaN = is_unordered_quiet(xtmp1) + evcmpsh(ktmp, k0, xtmp1, xtmp1, Assembler::UNORD_Q); + // Final result is same as first source if its a NaN value, + // in case second operand holds a NaN value then as per above semantics + // result is same as second operand. + Assembler::evmovdquw(dst, ktmp, xtmp1, true, vlen_enc); + } +} diff --git a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp index 4fe2cc397b5..29380609b9a 100644 --- a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp +++ b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp @@ -584,4 +584,6 @@ public: void select_from_two_vectors_evex(BasicType elem_bt, XMMRegister dst, XMMRegister src1, XMMRegister src2, int vlen_enc); + void scalar_max_min_fp16(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2, + KRegister ktmp, XMMRegister xtmp1, XMMRegister xtmp2, int vlen_enc); #endif // CPU_X86_C2_MACROASSEMBLER_X86_HPP diff --git a/src/hotspot/cpu/x86/x86.ad b/src/hotspot/cpu/x86/x86.ad index 8b2c5835544..afa1a92287d 100644 --- a/src/hotspot/cpu/x86/x86.ad +++ b/src/hotspot/cpu/x86/x86.ad @@ -1461,11 +1461,14 @@ bool Matcher::match_rule_supported(int opcode) { return false; } break; + case Op_MaxHF: + case Op_MinHF: + if (!VM_Version::supports_avx512vlbw()) { + return false; + } // fallthrough case Op_AddHF: case Op_DivHF: case Op_FmaHF: - case Op_MaxHF: - case Op_MinHF: case Op_MulHF: case Op_ReinterpretS2HF: case Op_ReinterpretHF2S: @@ -10935,8 +10938,6 @@ instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2) %{ match(Set dst (AddHF src1 src2)); match(Set dst (DivHF src1 src2)); - match(Set dst (MaxHF src1 src2)); - match(Set dst (MinHF src1 src2)); match(Set dst (MulHF src1 src2)); match(Set dst (SubHF src1 src2)); format %{ "scalar_binop_fp16 $dst, $src1, $src2" %} @@ -10947,6 +10948,20 @@ instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2) ins_pipe(pipe_slow); %} +instruct scalar_minmax_HF_reg(regF dst, regF src1, regF src2, kReg ktmp, regF xtmp1, regF xtmp2) +%{ + match(Set dst (MaxHF src1 src2)); + match(Set dst (MinHF src1 src2)); + effect(TEMP_DEF dst, TEMP ktmp, TEMP xtmp1, TEMP xtmp2); + format %{ "scalar_min_max_fp16 $dst, $src1, $src2\t using $ktmp, $xtmp1 and $xtmp2 as TEMP" %} + ins_encode %{ + int opcode = this->ideal_Opcode(); + __ scalar_max_min_fp16(opcode, $dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister, $ktmp$$KRegister, + $xtmp1$$XMMRegister, $xtmp2$$XMMRegister, Assembler::AVX_128bit); + %} + ins_pipe( pipe_slow ); +%} + instruct scalar_fma_HF_reg(regF dst, regF src1, regF src2) %{ match(Set dst (FmaHF src2 (Binary dst src1))); diff --git a/test/hotspot/jtreg/compiler/intrinsics/float16/TestFloat16MaxMinSpecialValues.java b/test/hotspot/jtreg/compiler/intrinsics/float16/TestFloat16MaxMinSpecialValues.java new file mode 100644 index 00000000000..f83ca307d84 --- /dev/null +++ b/test/hotspot/jtreg/compiler/intrinsics/float16/TestFloat16MaxMinSpecialValues.java @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package compiler.intrinsics.float16; + +import compiler.lib.ir_framework.*; +import jdk.incubator.vector.*; +import java.util.Random; +import jdk.test.lib.*; + +/** + * @test + * @bug 8352585 + * @library /test/lib / + * @summary Add special case handling for Float16.max/min x86 backend + * @modules jdk.incubator.vector + * @run driver compiler.intrinsics.float16.TestFloat16MaxMinSpecialValues + */ + + +public class TestFloat16MaxMinSpecialValues { + public static Float16 POS_ZERO = Float16.valueOf(0.0f); + public static Float16 NEG_ZERO = Float16.valueOf(-0.0f); + public static Float16 SRC = Float16.valueOf(Float.MAX_VALUE); + public static Random rd = Utils.getRandomInstance(); + + public static Float16 genNaN() { + // IEEE 754 Half Precision QNaN Format + // S EEEEE MMMMMMMMMM + // X 11111 1XXXXXXXXX + short sign = (short)(rd.nextBoolean() ? 1 << 15 : 0); + short significand = (short)rd.nextInt(512); + return Float16.shortBitsToFloat16((short)(sign | 0x7E00 | significand)); + } + + public static boolean assertionCheck(Float16 actual, Float16 expected) { + return !actual.equals(expected); + } + + public Float16 RES; + + public static void main(String [] args) { + TestFramework.runWithFlags("--add-modules=jdk.incubator.vector"); + } + + @Test + @IR(counts = {IRNode.MAX_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"}) + public Float16 testMaxNaNOperands(Float16 src1, Float16 src2) { + return Float16.max(src1, src2); + } + + @Run(test = "testMaxNaNOperands") + public void launchMaxNaNOperands() { + Float16 NAN = null; + for (int i = 0; i < 100; i++) { + NAN = genNaN(); + RES = testMaxNaNOperands(SRC, NAN); + if (assertionCheck(RES, NAN)) { + throw new AssertionError("input1 = " + SRC.floatValue() + " input2 = NaN , expected = NaN, actual = " + RES.floatValue()); + } + NAN = genNaN(); + RES = testMaxNaNOperands(NAN, SRC); + if (assertionCheck(RES, NAN)) { + throw new AssertionError("input1 = NaN, input2 = " + SRC.floatValue() + ", expected = NaN, actual = " + RES.floatValue()); + } + NAN = genNaN(); + RES = testMaxNaNOperands(NAN, NAN); + if (assertionCheck(RES, NAN)) { + throw new AssertionError("input1 = NaN, input2 = NaN, expected = NaN, actual = " + RES.floatValue()); + } + } + } + + @Test + @IR(counts = {IRNode.MIN_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"}) + public Float16 testMinNaNOperands(Float16 src1, Float16 src2) { + return Float16.min(src1, src2); + } + + @Run(test = "testMinNaNOperands") + public void launchMinNaNOperands() { + Float16 NAN = null; + for (int i = 0; i < 100; i++) { + NAN = genNaN(); + RES = testMinNaNOperands(SRC, NAN); + if (assertionCheck(RES, NAN)) { + throw new AssertionError("input1 = " + SRC.floatValue() + " input2 = NaN, expected = NaN, actual = " + RES.floatValue()); + } + NAN = genNaN(); + RES = testMinNaNOperands(NAN, SRC); + if (assertionCheck(RES, NAN)) { + throw new AssertionError("input1 = NaN, input2 = " + SRC.floatValue() + ", expected = NaN, actual = " + RES.floatValue()); + } + NAN = genNaN(); + RES = testMinNaNOperands(NAN, NAN); + if (assertionCheck(RES, NAN)) { + throw new AssertionError("input1 = NaN, input2 = NaN, expected = NaN, actual = " + RES.floatValue()); + } + } + } + + @Test + @IR(counts = {IRNode.MAX_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"}) + public Float16 testMaxZeroOperands(Float16 src1, Float16 src2) { + return Float16.max(src1, src2); + } + + @Run(test = "testMaxZeroOperands") + public void launchMaxZeroOperands() { + RES = testMaxZeroOperands(POS_ZERO, NEG_ZERO); + if (assertionCheck(RES, POS_ZERO)) { + throw new AssertionError("input1 = +0.0, input2 = -0.0, expected = +0.0, actual = " + RES.floatValue()); + } + RES = testMaxZeroOperands(NEG_ZERO, POS_ZERO); + if (assertionCheck(RES, POS_ZERO)) { + throw new AssertionError("input1 = -0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue()); + } + RES = testMaxZeroOperands(POS_ZERO, POS_ZERO); + if (assertionCheck(RES, POS_ZERO)) { + throw new AssertionError("input1 = +0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue()); + } + RES = testMaxZeroOperands(NEG_ZERO, NEG_ZERO); + if (assertionCheck(RES, NEG_ZERO)) { + throw new AssertionError("input1 = -0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue()); + } + } + + @Test + @IR(counts = {IRNode.MIN_HF, " >0 "}, applyIfCPUFeatureAnd = {"avx512_fp16", "true", "avx512bw", "true", "avx512vl", "true"}) + public Float16 testMinZeroOperands(Float16 src1, Float16 src2) { + return Float16.min(src1, src2); + } + + @Run(test = "testMinZeroOperands") + public void launchMinZeroOperands() { + RES = testMinZeroOperands(POS_ZERO, NEG_ZERO); + if (assertionCheck(RES, NEG_ZERO)) { + throw new AssertionError("input1 = +0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue()); + } + + RES = testMinZeroOperands(NEG_ZERO, POS_ZERO); + if (assertionCheck(RES, NEG_ZERO)) { + throw new AssertionError("input1 = -0.0, input2 = +0.0, expected = -0.0, actual = " + RES.floatValue()); + } + + RES = testMinZeroOperands(POS_ZERO, POS_ZERO); + if (assertionCheck(RES, POS_ZERO)) { + throw new AssertionError("input1 = +0.0, input2 = +0.0, expected = +0.0, actual = " + RES.floatValue()); + } + + RES = testMinZeroOperands(NEG_ZERO, NEG_ZERO); + if (assertionCheck(RES, NEG_ZERO)) { + throw new AssertionError("input1 = -0.0, input2 = -0.0, expected = -0.0, actual = " + RES.floatValue()); + } + } +}