From 3384c6736daf81aba08e15d6340065517747736e Mon Sep 17 00:00:00 2001 From: Bhavana Kilambi Date: Wed, 15 Apr 2026 12:27:56 +0000 Subject: [PATCH] 8366444: Add support for add/mul reduction operations for Float16 Reviewed-by: jbhateja, mchevalier, xgong, epeter --- src/hotspot/cpu/aarch64/aarch64_vector.ad | 90 ++++++- src/hotspot/cpu/aarch64/aarch64_vector_ad.m4 | 139 +++++++--- .../cpu/aarch64/c2_MacroAssembler_aarch64.cpp | 49 ++++ .../cpu/aarch64/c2_MacroAssembler_aarch64.hpp | 3 + src/hotspot/share/adlc/formssel.cpp | 8 +- src/hotspot/share/opto/classes.hpp | 2 + src/hotspot/share/opto/compile.cpp | 8 +- src/hotspot/share/opto/vectornode.cpp | 18 +- src/hotspot/share/opto/vectornode.hpp | 69 ++++- .../compiler/lib/ir_framework/IRNode.java | 10 + .../loopopts/superword/TestReductions.java | 141 +++++++++- .../TestFloat16VectorOperations.java | 240 +++++++++++++++++- .../vector/Float16OperationsBenchmark.java | 19 ++ .../bench/vm/compiler/VectorReduction2.java | 93 ++++++- 14 files changed, 820 insertions(+), 69 deletions(-) diff --git a/src/hotspot/cpu/aarch64/aarch64_vector.ad b/src/hotspot/cpu/aarch64/aarch64_vector.ad index 30b0c9c799b..4c854913e63 100644 --- a/src/hotspot/cpu/aarch64/aarch64_vector.ad +++ b/src/hotspot/cpu/aarch64/aarch64_vector.ad @@ -1,6 +1,6 @@ // // Copyright (c) 2020, 2026, Oracle and/or its affiliates. All rights reserved. -// Copyright (c) 2020, 2025, Arm Limited. All rights reserved. +// Copyright (c) 2020, 2026, Arm Limited. All rights reserved. // DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. // // This code is free software; you can redistribute it and/or modify it @@ -247,10 +247,39 @@ source %{ case Op_MinVHF: case Op_MaxVHF: case Op_SqrtVHF: + if (UseSVE == 0 && !is_feat_fp16_supported()) { + return false; + } + break; + // At the time of writing this, the Vector API has no half-float (FP16) species. + // Consequently, AddReductionVHF and MulReductionVHF are only produced by the + // auto-vectorizer, which requires strictly ordered semantics for FP reductions. + // + // There is no direct Neon instruction that performs strictly ordered floating + // point add reduction. Hence, on Neon only machines, the add reduction operation + // is implemented as a scalarized sequence using half-precision scalar instruction + // FADD which requires FEAT_FP16 and ASIMDHP to be available on the target. + // On SVE machines (UseSVE > 0) however, there is a direct instruction (FADDA) which + // implements strictly ordered floating point add reduction which does not require + // the FEAT_FP16 and ASIMDHP checks as SVE supports half-precision floats by default. + case Op_AddReductionVHF: // FEAT_FP16 is enabled if both "fphp" and "asimdhp" features are supported. // Only the Neon instructions need this check. SVE supports half-precision floats // by default. - if (UseSVE == 0 && !is_feat_fp16_supported()) { + if (length_in_bytes < 8 || (UseSVE == 0 && !is_feat_fp16_supported())) { + return false; + } + break; + case Op_MulReductionVHF: + // There are no direct Neon/SVE instructions that perform strictly ordered + // floating point multiply reduction. + // For vector length ≤ 16 bytes, the reduction is implemented as a scalarized + // sequence using half-precision scalar instruction FMUL. This path requires + // FEAT_FP16 and ASIMDHP to be available on the target. + // For vector length > 16 bytes, this operation is disabled because there is no + // direct SVE instruction that performs a strictly ordered FP16 multiply + // reduction. + if (length_in_bytes < 8 || length_in_bytes > 16 || !is_feat_fp16_supported()) { return false; } break; @@ -300,6 +329,7 @@ source %{ case Op_VectorRearrange: case Op_MulReductionVD: case Op_MulReductionVF: + case Op_MulReductionVHF: case Op_MulReductionVI: case Op_MulReductionVL: case Op_CompressBitsV: @@ -364,6 +394,7 @@ source %{ case Op_VectorMaskCmp: case Op_LoadVectorGather: case Op_StoreVectorScatter: + case Op_AddReductionVHF: case Op_AddReductionVF: case Op_AddReductionVD: case Op_AndReductionV: @@ -3402,6 +3433,44 @@ instruct reduce_non_strict_order_add4F_neon(vRegF dst, vRegF fsrc, vReg vsrc, vR ins_pipe(pipe_slow); %} +// Add Reduction for Half floats (FP16). +// Neon does not provide direct instructions for strictly ordered floating-point add reductions. +// On Neon-only targets (UseSVE = 0), this operation is implemented as a sequence of scalar additions: +// values equal to the vector width are loaded into a vector register, each lane is extracted, +// and its value is accumulated into the running sum, producing a final scalar result. +instruct reduce_addHF_neon(vRegF dst, vRegF fsrc, vReg vsrc, vReg tmp) %{ + predicate(UseSVE == 0); + match(Set dst (AddReductionVHF fsrc vsrc)); + effect(TEMP_DEF dst, TEMP tmp); + format %{ "reduce_addHF $dst, $fsrc, $vsrc\t# 4HF/8HF. KILL $tmp" %} + ins_encode %{ + uint length_in_bytes = Matcher::vector_length_in_bytes(this, $vsrc); + __ neon_reduce_add_fp16($dst$$FloatRegister, $fsrc$$FloatRegister, + $vsrc$$FloatRegister, length_in_bytes, $tmp$$FloatRegister); + %} + ins_pipe(pipe_slow); +%} + +// This rule calculates the reduction result in strict order. Two cases will +// reach here: +// 1. Non strictly-ordered AddReductionVHF when vector size > 128-bits. For example - +// AddReductionVHF generated by Vector API. For vector size > 128-bits, it is more +// beneficial performance-wise to generate direct SVE instruction even if it is +// strictly ordered. +// 2. Strictly-ordered AddReductionVHF. For example - AddReductionVHF generated by +// auto-vectorization on SVE machine. +instruct reduce_addHF_sve(vRegF dst_src1, vReg src2) %{ + predicate(UseSVE > 0); + match(Set dst_src1 (AddReductionVHF dst_src1 src2)); + format %{ "reduce_addHF_sve $dst_src1, $dst_src1, $src2" %} + ins_encode %{ + uint length_in_bytes = Matcher::vector_length_in_bytes(this, $src2); + assert(length_in_bytes == MaxVectorSize, "invalid vector length"); + __ sve_fadda($dst_src1$$FloatRegister, __ H, ptrue, $src2$$FloatRegister); + %} + ins_pipe(pipe_slow); +%} + // This rule calculates the reduction result in strict order. Two cases will // reach here: // 1. Non strictly-ordered AddReductionVF when vector size > 128-bits. For example - @@ -3492,12 +3561,14 @@ instruct reduce_addL_masked(iRegLNoSp dst, iRegL isrc, vReg vsrc, pRegGov pg, vR ins_pipe(pipe_slow); %} -instruct reduce_addF_masked(vRegF dst_src1, vReg src2, pRegGov pg) %{ +instruct reduce_addFHF_masked(vRegF dst_src1, vReg src2, pRegGov pg) %{ predicate(UseSVE > 0); + match(Set dst_src1 (AddReductionVHF (Binary dst_src1 src2) pg)); match(Set dst_src1 (AddReductionVF (Binary dst_src1 src2) pg)); - format %{ "reduce_addF_masked $dst_src1, $pg, $dst_src1, $src2" %} + format %{ "reduce_addFHF_masked $dst_src1, $pg, $dst_src1, $src2" %} ins_encode %{ - __ sve_fadda($dst_src1$$FloatRegister, __ S, + BasicType bt = Matcher::vector_element_basic_type(this, $src2); + __ sve_fadda($dst_src1$$FloatRegister, __ elemType_to_regVariant(bt), $pg$$PRegister, $src2$$FloatRegister); %} ins_pipe(pipe_slow); @@ -3545,14 +3616,17 @@ instruct reduce_mulL(iRegLNoSp dst, iRegL isrc, vReg vsrc) %{ ins_pipe(pipe_slow); %} -instruct reduce_mulF(vRegF dst, vRegF fsrc, vReg vsrc, vReg tmp) %{ + +instruct reduce_mulFHF(vRegF dst, vRegF fsrc, vReg vsrc, vReg tmp) %{ predicate(Matcher::vector_length_in_bytes(n->in(2)) <= 16); + match(Set dst (MulReductionVHF fsrc vsrc)); match(Set dst (MulReductionVF fsrc vsrc)); effect(TEMP_DEF dst, TEMP tmp); - format %{ "reduce_mulF $dst, $fsrc, $vsrc\t# 2F/4F. KILL $tmp" %} + format %{ "reduce_mulFHF $dst, $fsrc, $vsrc\t# 2F/4F/4HF/8HF. KILL $tmp" %} ins_encode %{ uint length_in_bytes = Matcher::vector_length_in_bytes(this, $vsrc); - __ neon_reduce_mul_fp($dst$$FloatRegister, T_FLOAT, $fsrc$$FloatRegister, + BasicType bt = Matcher::vector_element_basic_type(this, $vsrc); + __ neon_reduce_mul_fp($dst$$FloatRegister, bt, $fsrc$$FloatRegister, $vsrc$$FloatRegister, length_in_bytes, $tmp$$FloatRegister); %} ins_pipe(pipe_slow); diff --git a/src/hotspot/cpu/aarch64/aarch64_vector_ad.m4 b/src/hotspot/cpu/aarch64/aarch64_vector_ad.m4 index 48bffb3cf35..58ed234194a 100644 --- a/src/hotspot/cpu/aarch64/aarch64_vector_ad.m4 +++ b/src/hotspot/cpu/aarch64/aarch64_vector_ad.m4 @@ -1,6 +1,6 @@ // // Copyright (c) 2020, 2026, Oracle and/or its affiliates. All rights reserved. -// Copyright (c) 2020, 2025, Arm Limited. All rights reserved. +// Copyright (c) 2020, 2026, Arm Limited. All rights reserved. // DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. // // This code is free software; you can redistribute it and/or modify it @@ -237,10 +237,39 @@ source %{ case Op_MinVHF: case Op_MaxVHF: case Op_SqrtVHF: + if (UseSVE == 0 && !is_feat_fp16_supported()) { + return false; + } + break; + // At the time of writing this, the Vector API has no half-float (FP16) species. + // Consequently, AddReductionVHF and MulReductionVHF are only produced by the + // auto-vectorizer, which requires strictly ordered semantics for FP reductions. + // + // There is no direct Neon instruction that performs strictly ordered floating + // point add reduction. Hence, on Neon only machines, the add reduction operation + // is implemented as a scalarized sequence using half-precision scalar instruction + // FADD which requires FEAT_FP16 and ASIMDHP to be available on the target. + // On SVE machines (UseSVE > 0) however, there is a direct instruction (FADDA) which + // implements strictly ordered floating point add reduction which does not require + // the FEAT_FP16 and ASIMDHP checks as SVE supports half-precision floats by default. + case Op_AddReductionVHF: // FEAT_FP16 is enabled if both "fphp" and "asimdhp" features are supported. // Only the Neon instructions need this check. SVE supports half-precision floats // by default. - if (UseSVE == 0 && !is_feat_fp16_supported()) { + if (length_in_bytes < 8 || (UseSVE == 0 && !is_feat_fp16_supported())) { + return false; + } + break; + case Op_MulReductionVHF: + // There are no direct Neon/SVE instructions that perform strictly ordered + // floating point multiply reduction. + // For vector length ≤ 16 bytes, the reduction is implemented as a scalarized + // sequence using half-precision scalar instruction FMUL. This path requires + // FEAT_FP16 and ASIMDHP to be available on the target. + // For vector length > 16 bytes, this operation is disabled because there is no + // direct SVE instruction that performs a strictly ordered FP16 multiply + // reduction. + if (length_in_bytes < 8 || length_in_bytes > 16 || !is_feat_fp16_supported()) { return false; } break; @@ -290,6 +319,7 @@ source %{ case Op_VectorRearrange: case Op_MulReductionVD: case Op_MulReductionVF: + case Op_MulReductionVHF: case Op_MulReductionVI: case Op_MulReductionVL: case Op_CompressBitsV: @@ -354,6 +384,7 @@ source %{ case Op_VectorMaskCmp: case Op_LoadVectorGather: case Op_StoreVectorScatter: + case Op_AddReductionVHF: case Op_AddReductionVF: case Op_AddReductionVD: case Op_AndReductionV: @@ -2063,6 +2094,25 @@ instruct reduce_non_strict_order_add4F_neon(vRegF dst, vRegF fsrc, vReg vsrc, vR ins_pipe(pipe_slow); %} dnl + +// Add Reduction for Half floats (FP16). +// Neon does not provide direct instructions for strictly ordered floating-point add reductions. +// On Neon-only targets (UseSVE = 0), this operation is implemented as a sequence of scalar additions: +// values equal to the vector width are loaded into a vector register, each lane is extracted, +// and its value is accumulated into the running sum, producing a final scalar result. +instruct reduce_addHF_neon(vRegF dst, vRegF fsrc, vReg vsrc, vReg tmp) %{ + predicate(UseSVE == 0); + match(Set dst (AddReductionVHF fsrc vsrc)); + effect(TEMP_DEF dst, TEMP tmp); + format %{ "reduce_addHF $dst, $fsrc, $vsrc\t# 4HF/8HF. KILL $tmp" %} + ins_encode %{ + uint length_in_bytes = Matcher::vector_length_in_bytes(this, $vsrc); + __ neon_reduce_add_fp16($dst$$FloatRegister, $fsrc$$FloatRegister, + $vsrc$$FloatRegister, length_in_bytes, $tmp$$FloatRegister); + %} + ins_pipe(pipe_slow); +%} +dnl dnl REDUCE_ADD_FP_SVE($1, $2 ) dnl REDUCE_ADD_FP_SVE(type, size) define(`REDUCE_ADD_FP_SVE', ` @@ -2074,21 +2124,26 @@ define(`REDUCE_ADD_FP_SVE', ` // strictly ordered. // 2. Strictly-ordered AddReductionV$1. For example - AddReductionV$1 generated by // auto-vectorization on SVE machine. -instruct reduce_add$1_sve(vReg$1 dst_src1, vReg src2) %{ - predicate(!VM_Version::use_neon_for_vector(Matcher::vector_length_in_bytes(n->in(2))) || - n->as_Reduction()->requires_strict_order()); +instruct reduce_add$1_sve(vReg`'ifelse($1, HF, F, $1) dst_src1, vReg src2) %{ + ifelse($1, HF, + `predicate(UseSVE > 0);', + `predicate(!VM_Version::use_neon_for_vector(Matcher::vector_length_in_bytes(n->in(2))) || + n->as_Reduction()->requires_strict_order());') match(Set dst_src1 (AddReductionV$1 dst_src1 src2)); format %{ "reduce_add$1_sve $dst_src1, $dst_src1, $src2" %} ins_encode %{ - assert(UseSVE > 0, "must be sve"); - uint length_in_bytes = Matcher::vector_length_in_bytes(this, $src2); + ifelse($1, HF, `', + `assert(UseSVE > 0, "must be sve"); + ')dnl +uint length_in_bytes = Matcher::vector_length_in_bytes(this, $src2); assert(length_in_bytes == MaxVectorSize, "invalid vector length"); __ sve_fadda($dst_src1$$FloatRegister, __ $2, ptrue, $src2$$FloatRegister); %} ins_pipe(pipe_slow); %}')dnl dnl -REDUCE_ADD_FP_SVE(F, S) +REDUCE_ADD_FP_SVE(HF, H) +REDUCE_ADD_FP_SVE(F, S) // reduction addD @@ -2129,21 +2184,30 @@ dnl dnl REDUCE_ADD_FP_PREDICATE($1, $2 ) dnl REDUCE_ADD_FP_PREDICATE(insn_name, op_name) define(`REDUCE_ADD_FP_PREDICATE', ` -instruct reduce_add$1_masked(vReg$1 dst_src1, vReg src2, pRegGov pg) %{ +instruct reduce_add$1_masked(vReg$2 dst_src1, vReg src2, pRegGov pg) %{ predicate(UseSVE > 0); - match(Set dst_src1 (AddReductionV$1 (Binary dst_src1 src2) pg)); + ifelse($2, F, + `match(Set dst_src1 (AddReductionVHF (Binary dst_src1 src2) pg)); + match(Set dst_src1 (AddReductionV$2 (Binary dst_src1 src2) pg));', + `match(Set dst_src1 (AddReductionV$2 (Binary dst_src1 src2) pg));') format %{ "reduce_add$1_masked $dst_src1, $pg, $dst_src1, $src2" %} ins_encode %{ - __ sve_fadda($dst_src1$$FloatRegister, __ $2, - $pg$$PRegister, $src2$$FloatRegister); + ifelse($2, F, + `BasicType bt = Matcher::vector_element_basic_type(this, $src2); + ',)dnl +ifelse($2, F, + `__ sve_fadda($dst_src1$$FloatRegister, __ elemType_to_regVariant(bt), + $pg$$PRegister, $src2$$FloatRegister);', + `__ sve_fadda($dst_src1$$FloatRegister, __ $2, + $pg$$PRegister, $src2$$FloatRegister);') %} ins_pipe(pipe_slow); %}')dnl dnl REDUCE_ADD_INT_PREDICATE(I, iRegIorL2I) REDUCE_ADD_INT_PREDICATE(L, iRegL) -REDUCE_ADD_FP_PREDICATE(F, S) -REDUCE_ADD_FP_PREDICATE(D, D) +REDUCE_ADD_FP_PREDICATE(FHF, F) +REDUCE_ADD_FP_PREDICATE(D, D) // ------------------------------ Vector reduction mul ------------------------- @@ -2176,30 +2240,37 @@ instruct reduce_mulL(iRegLNoSp dst, iRegL isrc, vReg vsrc) %{ ins_pipe(pipe_slow); %} -instruct reduce_mulF(vRegF dst, vRegF fsrc, vReg vsrc, vReg tmp) %{ - predicate(Matcher::vector_length_in_bytes(n->in(2)) <= 16); - match(Set dst (MulReductionVF fsrc vsrc)); +dnl REDUCE_MUL_FP($1, $2 ) +dnl REDUCE_MUL_FP(insn_name, op_name) +define(`REDUCE_MUL_FP', ` +instruct reduce_mul$1(vReg$2 dst, vReg$2 ifelse($2, F, fsrc, dsrc), vReg vsrc, vReg tmp) %{ + predicate(Matcher::vector_length_in_bytes(n->in(2)) ifelse($2, F, <=, ==) 16); + ifelse($2, F, + `match(Set dst (MulReductionVHF fsrc vsrc)); + match(Set dst (MulReductionV$2 fsrc vsrc));', + `match(Set dst (MulReductionV$2 dsrc vsrc));') effect(TEMP_DEF dst, TEMP tmp); - format %{ "reduce_mulF $dst, $fsrc, $vsrc\t# 2F/4F. KILL $tmp" %} + ifelse($2, F, + `format %{ "reduce_mul$1 $dst, $fsrc, $vsrc\t# 2F/4F/4HF/8HF. KILL $tmp" %}', + `format %{ "reduce_mul$1 $dst, $dsrc, $vsrc\t# 2D. KILL $tmp" %}') ins_encode %{ - uint length_in_bytes = Matcher::vector_length_in_bytes(this, $vsrc); - __ neon_reduce_mul_fp($dst$$FloatRegister, T_FLOAT, $fsrc$$FloatRegister, - $vsrc$$FloatRegister, length_in_bytes, $tmp$$FloatRegister); + ifelse($2, F, + `uint length_in_bytes = Matcher::vector_length_in_bytes(this, $vsrc); + ',)dnl +ifelse($2, F, + `BasicType bt = Matcher::vector_element_basic_type(this, $vsrc); + ',)dnl +ifelse($2, F, + `__ neon_reduce_mul_fp($dst$$FloatRegister, bt, $fsrc$$FloatRegister, + $vsrc$$FloatRegister, length_in_bytes, $tmp$$FloatRegister);', + `__ neon_reduce_mul_fp($dst$$FloatRegister, T_DOUBLE, $dsrc$$FloatRegister, + $vsrc$$FloatRegister, 16, $tmp$$FloatRegister);') %} ins_pipe(pipe_slow); -%} - -instruct reduce_mulD(vRegD dst, vRegD dsrc, vReg vsrc, vReg tmp) %{ - predicate(Matcher::vector_length_in_bytes(n->in(2)) == 16); - match(Set dst (MulReductionVD dsrc vsrc)); - effect(TEMP_DEF dst, TEMP tmp); - format %{ "reduce_mulD $dst, $dsrc, $vsrc\t# 2D. KILL $tmp" %} - ins_encode %{ - __ neon_reduce_mul_fp($dst$$FloatRegister, T_DOUBLE, $dsrc$$FloatRegister, - $vsrc$$FloatRegister, 16, $tmp$$FloatRegister); - %} - ins_pipe(pipe_slow); -%} +%}')dnl +dnl +REDUCE_MUL_FP(FHF, F) +REDUCE_MUL_FP(D, D) dnl dnl REDUCE_BITWISE_OP_NEON($1, $2 $3 $4 ) diff --git a/src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.cpp b/src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.cpp index bba37a7a390..3c179f21c14 100644 --- a/src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.cpp +++ b/src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.cpp @@ -1,5 +1,6 @@ /* * Copyright (c) 2020, 2026, Oracle and/or its affiliates. All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -1883,6 +1884,27 @@ void C2_MacroAssembler::neon_reduce_mul_fp(FloatRegister dst, BasicType bt, BLOCK_COMMENT("neon_reduce_mul_fp {"); switch(bt) { + // The T_SHORT type below is for Float16 type which also uses floating-point + // instructions. + case T_SHORT: + fmulh(dst, fsrc, vsrc); + ext(vtmp, T8B, vsrc, vsrc, 2); + fmulh(dst, dst, vtmp); + ext(vtmp, T8B, vsrc, vsrc, 4); + fmulh(dst, dst, vtmp); + ext(vtmp, T8B, vsrc, vsrc, 6); + fmulh(dst, dst, vtmp); + if (isQ) { + ext(vtmp, T16B, vsrc, vsrc, 8); + fmulh(dst, dst, vtmp); + ext(vtmp, T16B, vsrc, vsrc, 10); + fmulh(dst, dst, vtmp); + ext(vtmp, T16B, vsrc, vsrc, 12); + fmulh(dst, dst, vtmp); + ext(vtmp, T16B, vsrc, vsrc, 14); + fmulh(dst, dst, vtmp); + } + break; case T_FLOAT: fmuls(dst, fsrc, vsrc); ins(vtmp, S, vsrc, 0, 1); @@ -1907,6 +1929,33 @@ void C2_MacroAssembler::neon_reduce_mul_fp(FloatRegister dst, BasicType bt, BLOCK_COMMENT("} neon_reduce_mul_fp"); } +// Vector reduction add for half float type with ASIMD instructions. +void C2_MacroAssembler::neon_reduce_add_fp16(FloatRegister dst, FloatRegister fsrc, FloatRegister vsrc, + unsigned vector_length_in_bytes, FloatRegister vtmp) { + assert(vector_length_in_bytes == 8 || vector_length_in_bytes == 16, "unsupported"); + bool isQ = vector_length_in_bytes == 16; + + BLOCK_COMMENT("neon_reduce_add_fp16 {"); + faddh(dst, fsrc, vsrc); + ext(vtmp, T8B, vsrc, vsrc, 2); + faddh(dst, dst, vtmp); + ext(vtmp, T8B, vsrc, vsrc, 4); + faddh(dst, dst, vtmp); + ext(vtmp, T8B, vsrc, vsrc, 6); + faddh(dst, dst, vtmp); + if (isQ) { + ext(vtmp, T16B, vsrc, vsrc, 8); + faddh(dst, dst, vtmp); + ext(vtmp, T16B, vsrc, vsrc, 10); + faddh(dst, dst, vtmp); + ext(vtmp, T16B, vsrc, vsrc, 12); + faddh(dst, dst, vtmp); + ext(vtmp, T16B, vsrc, vsrc, 14); + faddh(dst, dst, vtmp); + } + BLOCK_COMMENT("} neon_reduce_add_fp16"); +} + // Helper to select logical instruction void C2_MacroAssembler::neon_reduce_logical_helper(int opc, bool is64, Register Rd, Register Rn, Register Rm, diff --git a/src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.hpp b/src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.hpp index 5964bb60d4f..f96d3ffb863 100644 --- a/src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.hpp +++ b/src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.hpp @@ -177,6 +177,9 @@ FloatRegister fsrc, FloatRegister vsrc, unsigned vector_length_in_bytes, FloatRegister vtmp); + void neon_reduce_add_fp16(FloatRegister dst, FloatRegister fsrc, FloatRegister vsrc, + unsigned vector_length_in_bytes, FloatRegister vtmp); + void neon_reduce_logical(int opc, Register dst, BasicType bt, Register isrc, FloatRegister vsrc, unsigned vector_length_in_bytes); diff --git a/src/hotspot/share/adlc/formssel.cpp b/src/hotspot/share/adlc/formssel.cpp index 4dd2bff7c89..5802217c1c1 100644 --- a/src/hotspot/share/adlc/formssel.cpp +++ b/src/hotspot/share/adlc/formssel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1998, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1998, 2026, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -4233,11 +4233,13 @@ int MatchRule::is_expensive() const { strcmp(opType,"PopulateIndex")==0 || strcmp(opType,"AddReductionVI")==0 || strcmp(opType,"AddReductionVL")==0 || + strcmp(opType,"AddReductionVHF")==0 || strcmp(opType,"AddReductionVF")==0 || strcmp(opType,"AddReductionVD")==0 || strcmp(opType,"MulReductionVI")==0 || strcmp(opType,"MulReductionVL")==0 || strcmp(opType,"MulReductionVF")==0 || + strcmp(opType,"MulReductionVHF")==0 || strcmp(opType,"MulReductionVD")==0 || strcmp(opType,"MinReductionV")==0 || strcmp(opType,"MaxReductionV")==0 || @@ -4348,9 +4350,9 @@ bool MatchRule::is_vector() const { "MaxV", "MinV", "MinVHF", "MaxVHF", "UMinV", "UMaxV", "CompressV", "ExpandV", "CompressM", "CompressBitsV", "ExpandBitsV", "AddReductionVI", "AddReductionVL", - "AddReductionVF", "AddReductionVD", + "AddReductionVHF", "AddReductionVF", "AddReductionVD", "MulReductionVI", "MulReductionVL", - "MulReductionVF", "MulReductionVD", + "MulReductionVHF", "MulReductionVF", "MulReductionVD", "MaxReductionV", "MinReductionV", "AndReductionV", "OrReductionV", "XorReductionV", "MulAddVS2VI", "MacroLogicV", diff --git a/src/hotspot/share/opto/classes.hpp b/src/hotspot/share/opto/classes.hpp index 3c1a68d6224..0f67cf90183 100644 --- a/src/hotspot/share/opto/classes.hpp +++ b/src/hotspot/share/opto/classes.hpp @@ -396,6 +396,7 @@ macro(AddVL) macro(AddReductionVL) macro(AddVF) macro(AddVHF) +macro(AddReductionVHF) macro(AddReductionVF) macro(AddVD) macro(AddReductionVD) @@ -413,6 +414,7 @@ macro(MulReductionVI) macro(MulVL) macro(MulReductionVL) macro(MulVF) +macro(MulReductionVHF) macro(MulReductionVF) macro(MulVD) macro(MulReductionVD) diff --git a/src/hotspot/share/opto/compile.cpp b/src/hotspot/share/opto/compile.cpp index db3cbd4109c..e05df8ea716 100644 --- a/src/hotspot/share/opto/compile.cpp +++ b/src/hotspot/share/opto/compile.cpp @@ -3200,10 +3200,10 @@ void Compile::final_graph_reshaping_impl(Node *n, Final_Reshape_Counts& frc, Uni !n->in(2)->is_Con() ) { // right use is not a constant // Check for commutative opcode switch( nop ) { - case Op_AddI: case Op_AddF: case Op_AddD: case Op_AddL: + case Op_AddI: case Op_AddF: case Op_AddD: case Op_AddHF: case Op_AddL: case Op_MaxI: case Op_MaxL: case Op_MaxF: case Op_MaxD: case Op_MinI: case Op_MinL: case Op_MinF: case Op_MinD: - case Op_MulI: case Op_MulF: case Op_MulD: case Op_MulL: + case Op_MulI: case Op_MulF: case Op_MulD: case Op_MulHF: case Op_MulL: case Op_AndL: case Op_XorL: case Op_OrL: case Op_AndI: case Op_XorI: case Op_OrI: { // Move "last use" input to left by swapping inputs @@ -3282,6 +3282,8 @@ void Compile::handle_div_mod_op(Node* n, BasicType bt, bool is_unsigned) { void Compile::final_graph_reshaping_main_switch(Node* n, Final_Reshape_Counts& frc, uint nop, Unique_Node_List& dead_nodes) { switch( nop ) { // Count all float operations that may use FPU + case Op_AddHF: + case Op_MulHF: case Op_AddF: case Op_SubF: case Op_MulF: @@ -3788,10 +3790,12 @@ void Compile::final_graph_reshaping_main_switch(Node* n, Final_Reshape_Counts& f case Op_AddReductionVI: case Op_AddReductionVL: + case Op_AddReductionVHF: case Op_AddReductionVF: case Op_AddReductionVD: case Op_MulReductionVI: case Op_MulReductionVL: + case Op_MulReductionVHF: case Op_MulReductionVF: case Op_MulReductionVD: case Op_MinReductionV: diff --git a/src/hotspot/share/opto/vectornode.cpp b/src/hotspot/share/opto/vectornode.cpp index dbadc18da01..d19aa476196 100644 --- a/src/hotspot/share/opto/vectornode.cpp +++ b/src/hotspot/share/opto/vectornode.cpp @@ -1260,6 +1260,10 @@ int ReductionNode::opcode(int opc, BasicType bt) { assert(bt == T_LONG, "must be"); vopc = Op_AddReductionVL; break; + case Op_AddHF: + assert(bt == T_SHORT, "must be"); + vopc = Op_AddReductionVHF; + break; case Op_AddF: assert(bt == T_FLOAT, "must be"); vopc = Op_AddReductionVF; @@ -1284,6 +1288,10 @@ int ReductionNode::opcode(int opc, BasicType bt) { assert(bt == T_LONG, "must be"); vopc = Op_MulReductionVL; break; + case Op_MulHF: + assert(bt == T_SHORT, "must be"); + vopc = Op_MulReductionVHF; + break; case Op_MulF: assert(bt == T_FLOAT, "must be"); vopc = Op_MulReductionVF; @@ -1432,10 +1440,12 @@ ReductionNode* ReductionNode::make(int opc, Node* ctrl, Node* n1, Node* n2, Basi switch (vopc) { case Op_AddReductionVI: return new AddReductionVINode(ctrl, n1, n2); case Op_AddReductionVL: return new AddReductionVLNode(ctrl, n1, n2); + case Op_AddReductionVHF: return new AddReductionVHFNode(ctrl, n1, n2, requires_strict_order); case Op_AddReductionVF: return new AddReductionVFNode(ctrl, n1, n2, requires_strict_order); case Op_AddReductionVD: return new AddReductionVDNode(ctrl, n1, n2, requires_strict_order); case Op_MulReductionVI: return new MulReductionVINode(ctrl, n1, n2); case Op_MulReductionVL: return new MulReductionVLNode(ctrl, n1, n2); + case Op_MulReductionVHF: return new MulReductionVHFNode(ctrl, n1, n2, requires_strict_order); case Op_MulReductionVF: return new MulReductionVFNode(ctrl, n1, n2, requires_strict_order); case Op_MulReductionVD: return new MulReductionVDNode(ctrl, n1, n2, requires_strict_order); case Op_MinReductionV: return new MinReductionVNode (ctrl, n1, n2); @@ -1613,6 +1623,8 @@ Node* ReductionNode::make_identity_con_scalar(PhaseGVN& gvn, int sopc, BasicType return nullptr; } break; + case Op_AddReductionVHF: + return gvn.makecon(TypeH::ZERO); case Op_AddReductionVI: // fallthrough case Op_AddReductionVL: // fallthrough case Op_AddReductionVF: // fallthrough @@ -1624,6 +1636,8 @@ Node* ReductionNode::make_identity_con_scalar(PhaseGVN& gvn, int sopc, BasicType return gvn.makecon(TypeInt::ONE); case Op_MulReductionVL: return gvn.makecon(TypeLong::ONE); + case Op_MulReductionVHF: + return gvn.makecon(TypeH::ONE); case Op_MulReductionVF: return gvn.makecon(TypeF::ONE); case Op_MulReductionVD: @@ -1716,12 +1730,14 @@ bool ReductionNode::auto_vectorization_requires_strict_order(int vopc) { // These are cases that all have associative operations, which can // thus be reordered, allowing non-strict order reductions. return false; + case Op_AddReductionVHF: + case Op_MulReductionVHF: case Op_AddReductionVF: case Op_MulReductionVF: case Op_AddReductionVD: case Op_MulReductionVD: // Floating-point addition and multiplication are non-associative, - // so AddReductionVF/D and MulReductionVF/D require strict ordering + // so AddReductionVHF/VF/VD and MulReductionVHF/VF/VD require strict ordering // in auto-vectorization. return true; default: diff --git a/src/hotspot/share/opto/vectornode.hpp b/src/hotspot/share/opto/vectornode.hpp index de866898302..91cff9fcae8 100644 --- a/src/hotspot/share/opto/vectornode.hpp +++ b/src/hotspot/share/opto/vectornode.hpp @@ -1,5 +1,6 @@ /* * Copyright (c) 2007, 2026, Oracle and/or its affiliates. All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -322,7 +323,7 @@ class ReductionNode : public Node { virtual uint size_of() const { return sizeof(*this); } // Floating-point addition and multiplication are non-associative, so - // AddReductionVF/D and MulReductionVF/D require strict ordering + // AddReductionVHF/F/D and MulReductionVHF/F/D require strict ordering // in auto-vectorization. Vector API can generate AddReductionVF/D // and MulReductionVF/VD without strict ordering, which can benefit // some platforms. @@ -359,6 +360,35 @@ public: virtual int Opcode() const; }; +// Vector add half float as a reduction +class AddReductionVHFNode : public ReductionNode { +private: + // True if add reduction operation for half floats requires strict ordering. + // As an example - The value is true when add reduction for half floats is auto-vectorized + // as auto-vectorization mandates strict ordering but the value is false when this node + // is generated through VectorAPI as VectorAPI does not impose any such rules on ordering. + const bool _requires_strict_order; + +public: + // _requires_strict_order is set to true by default as mandated by auto-vectorization + AddReductionVHFNode(Node* ctrl, Node* in1, Node* in2, bool requires_strict_order = true) : + ReductionNode(ctrl, in1, in2), _requires_strict_order(requires_strict_order) {} + + int Opcode() const override; + bool requires_strict_order() const override { return _requires_strict_order; } + + uint hash() const override { return Node::hash() + _requires_strict_order; } + + bool cmp(const Node& n) const override { + return Node::cmp(n) && _requires_strict_order == ((ReductionNode&)n).requires_strict_order(); + } + + uint size_of() const override { return sizeof(*this); } + + const Type* bottom_type() const override { return Type::HALF_FLOAT; } + uint ideal_reg() const override { return Op_RegF; } +}; + // Vector add float as a reduction class AddReductionVFNode : public ReductionNode { private: @@ -368,7 +398,7 @@ private: // is generated through VectorAPI as VectorAPI does not impose any such rules on ordering. const bool _requires_strict_order; public: - //_requires_strict_order is set to true by default as mandated by auto-vectorization + // _requires_strict_order is set to true by default as mandated by auto-vectorization AddReductionVFNode(Node* ctrl, Node* in1, Node* in2, bool requires_strict_order = true) : ReductionNode(ctrl, in1, in2), _requires_strict_order(requires_strict_order) {} @@ -394,7 +424,7 @@ private: // is generated through VectorAPI as VectorAPI does not impose any such rules on ordering. const bool _requires_strict_order; public: - //_requires_strict_order is set to true by default as mandated by auto-vectorization + // _requires_strict_order is set to true by default as mandated by auto-vectorization AddReductionVDNode(Node* ctrl, Node* in1, Node* in2, bool requires_strict_order = true) : ReductionNode(ctrl, in1, in2), _requires_strict_order(requires_strict_order) {} @@ -578,6 +608,35 @@ public: virtual int Opcode() const; }; +// Vector multiply half float as a reduction +class MulReductionVHFNode : public ReductionNode { +private: + // True if mul reduction operation for half floats requires strict ordering. + // As an example - The value is true when mul reduction for half floats is auto-vectorized + // as auto-vectorization mandates strict ordering but the value is false when this node + // is generated through VectorAPI as VectorAPI does not impose any such rules on ordering. + const bool _requires_strict_order; + +public: + // _requires_strict_order is set to true by default as mandated by auto-vectorization + MulReductionVHFNode(Node* ctrl, Node* in1, Node* in2, bool requires_strict_order = true) : + ReductionNode(ctrl, in1, in2), _requires_strict_order(requires_strict_order) {} + + int Opcode() const override; + bool requires_strict_order() const override { return _requires_strict_order; } + + uint hash() const override { return Node::hash() + _requires_strict_order; } + + bool cmp(const Node& n) const override { + return Node::cmp(n) && _requires_strict_order == ((ReductionNode&)n).requires_strict_order(); + } + + uint size_of() const override { return sizeof(*this); } + + const Type* bottom_type() const override { return Type::HALF_FLOAT; } + uint ideal_reg() const override { return Op_RegF; } +}; + // Vector multiply float as a reduction class MulReductionVFNode : public ReductionNode { // True if mul reduction operation for floats requires strict ordering. @@ -586,7 +645,7 @@ class MulReductionVFNode : public ReductionNode { // is generated through VectorAPI as VectorAPI does not impose any such rules on ordering. const bool _requires_strict_order; public: - //_requires_strict_order is set to true by default as mandated by auto-vectorization + // _requires_strict_order is set to true by default as mandated by auto-vectorization MulReductionVFNode(Node* ctrl, Node* in1, Node* in2, bool requires_strict_order = true) : ReductionNode(ctrl, in1, in2), _requires_strict_order(requires_strict_order) {} @@ -611,7 +670,7 @@ class MulReductionVDNode : public ReductionNode { // is generated through VectorAPI as VectorAPI does not impose any such rules on ordering. const bool _requires_strict_order; public: - //_requires_strict_order is set to true by default as mandated by auto-vectorization + // _requires_strict_order is set to true by default as mandated by auto-vectorization MulReductionVDNode(Node* ctrl, Node* in1, Node* in2, bool requires_strict_order = true) : ReductionNode(ctrl, in1, in2), _requires_strict_order(requires_strict_order) {} diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java index f3fc4afb170..55d591acdb3 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java @@ -323,6 +323,11 @@ public class IRNode { superWordNodes(ADD_REDUCTION_VF, "AddReductionVF"); } + public static final String ADD_REDUCTION_VHF = PREFIX + "ADD_REDUCTION_VHF" + POSTFIX; + static { + superWordNodes(ADD_REDUCTION_VHF, "AddReductionVHF"); + } + public static final String ADD_REDUCTION_VI = PREFIX + "ADD_REDUCTION_VI" + POSTFIX; static { superWordNodes(ADD_REDUCTION_VI, "AddReductionVI"); @@ -1576,6 +1581,11 @@ public class IRNode { superWordNodes(MUL_REDUCTION_VF, "MulReductionVF"); } + public static final String MUL_REDUCTION_VHF = PREFIX + "MUL_REDUCTION_VHF" + POSTFIX; + static { + superWordNodes(MUL_REDUCTION_VHF, "MulReductionVHF"); + } + public static final String MUL_REDUCTION_VI = PREFIX + "MUL_REDUCTION_VI" + POSTFIX; static { superWordNodes(MUL_REDUCTION_VI, "MulReductionVI"); diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/TestReductions.java b/test/hotspot/jtreg/compiler/loopopts/superword/TestReductions.java index 5c085e6a3a3..97a55ae2074 100644 --- a/test/hotspot/jtreg/compiler/loopopts/superword/TestReductions.java +++ b/test/hotspot/jtreg/compiler/loopopts/superword/TestReductions.java @@ -1,5 +1,6 @@ /* * Copyright (c) 2024, 2026, Oracle and/or its affiliates. All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -25,6 +26,7 @@ * @test id=no-vectorization * @bug 8340093 8342095 * @summary Test vectorization of reduction loops. + * @modules jdk.incubator.vector * @library /test/lib / * @run driver compiler.loopopts.superword.TestReductions P0 */ @@ -33,6 +35,7 @@ * @test id=vanilla * @bug 8340093 8342095 * @summary Test vectorization of reduction loops. + * @modules jdk.incubator.vector * @library /test/lib / * @run driver compiler.loopopts.superword.TestReductions P1 */ @@ -41,6 +44,7 @@ * @test id=force-vectorization * @bug 8340093 8342095 * @summary Test vectorization of reduction loops. + * @modules jdk.incubator.vector * @library /test/lib / * @run driver compiler.loopopts.superword.TestReductions P2 */ @@ -50,10 +54,14 @@ package compiler.loopopts.superword; import java.util.Map; import java.util.HashMap; +import jdk.incubator.vector.Float16; + import compiler.lib.ir_framework.*; import compiler.lib.verify.*; import static compiler.lib.generators.Generators.G; import compiler.lib.generators.Generator; +import static java.lang.Float.floatToFloat16; +import static jdk.incubator.vector.Float16.*; /** * Note: there is a corresponding JMH benchmark: @@ -65,6 +73,7 @@ public class TestReductions { private static final Generator GEN_L = G.longs(); private static final Generator GEN_F = G.floats(); private static final Generator GEN_D = G.doubles(); + private static final Generator GEN_F16 = G.float16s(); private static byte[] in1B = fillRandom(new byte[SIZE]); private static byte[] in2B = fillRandom(new byte[SIZE]); @@ -89,6 +98,9 @@ public class TestReductions { private static double[] in1D = fillRandom(new double[SIZE]); private static double[] in2D = fillRandom(new double[SIZE]); private static double[] in3D = fillRandom(new double[SIZE]); + private static short[] in1F16 = fillRandomFloat16(new short[SIZE]); + private static short[] in2F16 = fillRandomFloat16(new short[SIZE]); + private static short[] in3F16 = fillRandomFloat16(new short[SIZE]); interface TestFunction { Object run(); @@ -102,6 +114,7 @@ public class TestReductions { public static void main(String[] args) { TestFramework framework = new TestFramework(TestReductions.class); + framework.addFlags("--add-modules=jdk.incubator.vector"); switch (args[0]) { case "P0" -> { framework.addFlags("-XX:+UnlockDiagnosticVMOptions", "-XX:AutoVectorizationOverrideProfitability=0"); } case "P1" -> { framework.addFlags("-XX:+UnlockDiagnosticVMOptions", "-XX:AutoVectorizationOverrideProfitability=1"); } @@ -250,6 +263,13 @@ public class TestReductions { tests.put("doubleMinBig", TestReductions::doubleMinBig); tests.put("doubleMaxBig", TestReductions::doubleMaxBig); + tests.put("float16AddSimple", TestReductions::float16AddSimple); + tests.put("float16MulSimple", TestReductions::float16MulSimple); + tests.put("float16AddDotProduct", TestReductions::float16AddDotProduct); + tests.put("float16MulDotProduct", TestReductions::float16MulDotProduct); + tests.put("float16AddBig", TestReductions::float16AddBig); + tests.put("float16MulBig", TestReductions::float16MulBig); + // Compute gold value for all test methods before compilation for (Map.Entry entry : tests.entrySet()) { String name = entry.getKey(); @@ -394,7 +414,14 @@ public class TestReductions { "doubleAddBig", "doubleMulBig", "doubleMinBig", - "doubleMaxBig"}) + "doubleMaxBig", + + "float16AddSimple", + "float16MulSimple", + "float16AddDotProduct", + "float16MulDotProduct", + "float16AddBig", + "float16MulBig"}) public void runTests() { for (Map.Entry entry : tests.entrySet()) { String name = entry.getKey(); @@ -453,6 +480,13 @@ public class TestReductions { return a; } + static short[] fillRandomFloat16(short[] a) { + for (int i = 0; i < a.length; i++) { + a[i] = GEN_F16.next(); + } + return a; + } + // ---------byte***Simple ------------------------------------------------------------ @Test @IR(counts = {IRNode.LOAD_VECTOR_B, IRNode.VECTOR_SIZE + "min(max_int, max_byte)", "> 0", @@ -2628,5 +2662,110 @@ public class TestReductions { return acc; } + // ---------float16***Simple ------------------------------------------------------------ + @Test + @IR(counts = {IRNode.ADD_REDUCTION_VHF, "> 0"}, + applyIfCPUFeature = {"sve", "true"}, + applyIf = {"AutoVectorizationOverrideProfitability", "> 0"}) + @IR(counts = {IRNode.ADD_REDUCTION_VHF, "> 0"}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}, + applyIf = {"AutoVectorizationOverrideProfitability", "> 0"}) + @IR(failOn = IRNode.ADD_REDUCTION_VHF, + applyIf = {"AutoVectorizationOverrideProfitability", "= 0"}) + private static Float16 float16AddSimple() { + short acc = (short)0; // neutral element + for (int i = 0; i < SIZE; i++) { + acc = float16ToRawShortBits(add(shortBitsToFloat16(acc), shortBitsToFloat16(in1F16[i]))); + } + return shortBitsToFloat16(acc); + } + + @Test + @IR(counts = {IRNode.MUL_REDUCTION_VHF, "> 0"}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}, + applyIfAnd = {"AutoVectorizationOverrideProfitability", "> 0", "MaxVectorSize", "<=16"}) + @IR(failOn = IRNode.MUL_REDUCTION_VHF, + applyIf = {"AutoVectorizationOverrideProfitability", "= 0"}) + private static Float16 float16MulSimple() { + short acc = floatToFloat16(1.0f); // neutral element + for (int i = 0; i < SIZE; i++) { + acc = float16ToRawShortBits(multiply(shortBitsToFloat16(acc), shortBitsToFloat16(in1F16[i]))); + } + return shortBitsToFloat16(acc); + } + + // ---------float16***DotProduct ------------------------------------------------------------ + @Test + @IR(counts = {IRNode.ADD_REDUCTION_VHF, "> 0"}, + applyIfCPUFeature = {"sve", "true"}, + applyIf = {"AutoVectorizationOverrideProfitability", "> 0"}) + @IR(counts = {IRNode.ADD_REDUCTION_VHF, "> 0"}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}, + applyIf = {"AutoVectorizationOverrideProfitability", "> 0"}) + @IR(failOn = IRNode.ADD_REDUCTION_VHF, + applyIf = {"AutoVectorizationOverrideProfitability", "= 0"}) + private static Float16 float16AddDotProduct() { + short acc = (short)0; // neutral element + for (int i = 0; i < SIZE; i++) { + Float16 val = multiply(shortBitsToFloat16(in1F16[i]), shortBitsToFloat16(in2F16[i])); + acc = float16ToRawShortBits(add(shortBitsToFloat16(acc), val)); + } + return shortBitsToFloat16(acc); + } + + @Test + @IR(counts = {IRNode.MUL_REDUCTION_VHF, "> 0"}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}, + applyIfAnd = {"AutoVectorizationOverrideProfitability", "> 0", "MaxVectorSize", "<=16"}) + @IR(failOn = IRNode.MUL_REDUCTION_VHF, + applyIf = {"AutoVectorizationOverrideProfitability", "= 0"}) + private static Float16 float16MulDotProduct() { + short acc = floatToFloat16(1.0f); // neutral element + for (int i = 0; i < SIZE; i++) { + Float16 val = multiply(shortBitsToFloat16(in1F16[i]), shortBitsToFloat16(in2F16[i])); + acc = float16ToRawShortBits(multiply(shortBitsToFloat16(acc), val)); + } + return shortBitsToFloat16(acc); + } + + // ---------float16***Big ------------------------------------------------------------ + @Test + @IR(counts = {IRNode.ADD_REDUCTION_VHF, "> 0"}, + applyIfCPUFeature = {"sve", "true"}, + applyIf = {"AutoVectorizationOverrideProfitability", "> 0"}) + @IR(counts = {IRNode.ADD_REDUCTION_VHF, "> 0"}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}, + applyIf = {"AutoVectorizationOverrideProfitability", "> 0"}) + @IR(failOn = IRNode.ADD_REDUCTION_VHF, + applyIf = {"AutoVectorizationOverrideProfitability", "= 0"}) + private static Float16 float16AddBig() { + short acc = (short)0; // neutral element + for (int i = 0; i < SIZE; i++) { + Float16 a = shortBitsToFloat16(in1F16[i]); + Float16 b = shortBitsToFloat16(in2F16[i]); + Float16 c = shortBitsToFloat16(in3F16[i]); + Float16 val = add(multiply(a, b), add(multiply(a, c), multiply(b, c))); + acc = float16ToRawShortBits(add(shortBitsToFloat16(acc), val)); + } + return shortBitsToFloat16(acc); + } + + @Test + @IR(counts = {IRNode.MUL_REDUCTION_VHF, "> 0"}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}, + applyIfAnd = {"AutoVectorizationOverrideProfitability", "> 0", "MaxVectorSize", "<=16"}) + @IR(failOn = IRNode.MUL_REDUCTION_VHF, + applyIf = {"AutoVectorizationOverrideProfitability", "= 0"}) + private static Float16 float16MulBig() { + short acc = floatToFloat16(1.0f); // neutral element + for (int i = 0; i < SIZE; i++) { + Float16 a = shortBitsToFloat16(in1F16[i]); + Float16 b = shortBitsToFloat16(in2F16[i]); + Float16 c = shortBitsToFloat16(in3F16[i]); + Float16 val = add(multiply(a, b), add(multiply(a, c), multiply(b, c))); + acc = float16ToRawShortBits(multiply(shortBitsToFloat16(acc), val)); + } + return shortBitsToFloat16(acc); + } } diff --git a/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorOperations.java b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorOperations.java index f3c27c4d278..929a70f304a 100644 --- a/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorOperations.java +++ b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorOperations.java @@ -1,6 +1,6 @@ /* * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. - * Copyright (c) 2025, Arm Limited. All rights reserved. + * Copyright 2025, 2026 Arm Limited and/or its affiliates. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -33,19 +33,21 @@ */ package compiler.vectorization; -import compiler.lib.ir_framework.*; -import jdk.incubator.vector.Float16; -import static jdk.incubator.vector.Float16.*; -import static java.lang.Float.*; -import java.util.Arrays; -import jdk.test.lib.*; import compiler.lib.generators.Generator; +import compiler.lib.ir_framework.*; +import compiler.lib.verify.Verify; +import java.util.Arrays; +import jdk.incubator.vector.Float16; +import jdk.test.lib.*; import static compiler.lib.generators.Generators.G; +import static java.lang.Float.*; +import static jdk.incubator.vector.Float16.*; public class TestFloat16VectorOperations { private short[] input1; private short[] input2; private short[] input3; + private Float16[] input4; private short[] output; private static short FP16_SCALAR = (short)0x7777; private static final int LEN = 2048; @@ -77,6 +79,7 @@ public class TestFloat16VectorOperations { input1 = new short[LEN]; input2 = new short[LEN]; input3 = new short[LEN]; + input4 = new Float16[LEN]; output = new short[LEN]; short min_value = float16ToRawShortBits(Float16.MIN_VALUE); @@ -86,6 +89,7 @@ public class TestFloat16VectorOperations { input1[i] = gen.next(); input2[i] = gen.next(); input3[i] = gen.next(); + input4[i] = shortBitsToFloat16(gen.next()); } } @@ -349,7 +353,9 @@ public class TestFloat16VectorOperations { @Test @Warmup(50) @IR(counts = {IRNode.SUB_VHF, " >0 "}, - applyIfCPUFeature = {"avx512_fp16", "true"}) + applyIfCPUFeatureOr = {"avx512_fp16", "true", "sve", "true"}) + @IR(counts = {IRNode.SUB_VHF, " >0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}) public void vectorSubConstInputFloat16() { for (int i = 0; i < LEN; ++i) { output[i] = float16ToRawShortBits(subtract(shortBitsToFloat16(input1[i]), FP16_CONST)); @@ -367,7 +373,9 @@ public class TestFloat16VectorOperations { @Test @Warmup(50) @IR(counts = {IRNode.MUL_VHF, " >0 "}, - applyIfCPUFeature = {"avx512_fp16", "true"}) + applyIfCPUFeatureOr = {"avx512_fp16", "true", "sve", "true"}) + @IR(counts = {IRNode.MUL_VHF, " >0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}) public void vectorMulConstantInputFloat16() { for (int i = 0; i < LEN; ++i) { output[i] = float16ToRawShortBits(multiply(FP16_CONST, shortBitsToFloat16(input2[i]))); @@ -385,7 +393,9 @@ public class TestFloat16VectorOperations { @Test @Warmup(50) @IR(counts = {IRNode.DIV_VHF, " >0 "}, - applyIfCPUFeature = {"avx512_fp16", "true"}) + applyIfCPUFeatureOr = {"avx512_fp16", "true", "sve", "true"}) + @IR(counts = {IRNode.DIV_VHF, " >0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}) public void vectorDivConstantInputFloat16() { for (int i = 0; i < LEN; ++i) { output[i] = float16ToRawShortBits(divide(FP16_CONST, shortBitsToFloat16(input2[i]))); @@ -403,7 +413,9 @@ public class TestFloat16VectorOperations { @Test @Warmup(50) @IR(counts = {IRNode.MAX_VHF, " >0 "}, - applyIfCPUFeature = {"avx512_fp16", "true"}) + applyIfCPUFeatureOr = {"avx512_fp16", "true", "sve", "true"}) + @IR(counts = {IRNode.MAX_VHF, " >0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}) public void vectorMaxConstantInputFloat16() { for (int i = 0; i < LEN; ++i) { output[i] = float16ToRawShortBits(max(FP16_CONST, shortBitsToFloat16(input2[i]))); @@ -421,7 +433,9 @@ public class TestFloat16VectorOperations { @Test @Warmup(50) @IR(counts = {IRNode.MIN_VHF, " >0 "}, - applyIfCPUFeature = {"avx512_fp16", "true"}) + applyIfCPUFeatureOr = {"avx512_fp16", "true", "sve", "true"}) + @IR(counts = {IRNode.MIN_VHF, " >0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}) public void vectorMinConstantInputFloat16() { for (int i = 0; i < LEN; ++i) { output[i] = float16ToRawShortBits(min(FP16_CONST, shortBitsToFloat16(input2[i]))); @@ -435,4 +449,206 @@ public class TestFloat16VectorOperations { assertResults(2, float16ToRawShortBits(FP16_CONST), input2[i], expected, output[i]); } } + + @Test + @Warmup(50) + @IR(counts = {IRNode.ADD_REDUCTION_VHF, " >0 "}, + applyIfCPUFeature = {"sve", "true"}) + @IR(counts = {IRNode.ADD_REDUCTION_VHF, " >0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}) + public short vectorAddReductionFloat16() { + short result = (short) 0; + for (int i = 0; i < LEN; i++) { + result = float16ToRawShortBits(add(shortBitsToFloat16(result), shortBitsToFloat16(input1[i]))); + } + return result; + } + + @Check(test="vectorAddReductionFloat16") + public void checkResultAddReductionFloat16() { + short expected = (short) 0; + for (int i = 0; i < LEN; ++i) { + expected = floatToFloat16(float16ToFloat(expected) + float16ToFloat(input1[i])); + } + Verify.checkEQ(shortBitsToFloat16(expected), shortBitsToFloat16(vectorAddReductionFloat16())); + } + + @Test + @Warmup(50) + @IR(counts = {IRNode.MUL_REDUCTION_VHF, " >0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}, + applyIf = {"MaxVectorSize", "<=16"}) + public short vectorMulReductionFloat16() { + short result = floatToFloat16(1.0f); + for (int i = 0; i < LEN; i++) { + result = float16ToRawShortBits(multiply(shortBitsToFloat16(result), shortBitsToFloat16(input1[i]))); + } + return result; + } + + @Check(test="vectorMulReductionFloat16") + public void checkResultMulReductionFloat16() { + short expected = floatToFloat16(1.0f); + for (int i = 0; i < LEN; ++i) { + expected = floatToFloat16(float16ToFloat(expected) * float16ToFloat(input1[i])); + } + Verify.checkEQ(shortBitsToFloat16(expected), shortBitsToFloat16(vectorMulReductionFloat16())); + } + + // This test case verifies that autovectorization takes place in scenarios where masked + // add reduction instructions are required to be generated on platforms that support + // such masked/partial instructions. + @Test + @Warmup(500) + @IR(counts = {"reduce_addFHF_masked", " >0 "}, phase = {CompilePhase.FINAL_CODE}, + applyIfCPUFeature = {"sve", "true"}) + public short vectorAddReductionFloat16Partial() { + short result = (short) 0; + for (int i = 0; i < LEN; i+=8) { + result = float16ToRawShortBits(add(shortBitsToFloat16(result), shortBitsToFloat16(input1[i]))); + result = float16ToRawShortBits(add(shortBitsToFloat16(result), shortBitsToFloat16(input1[i+1]))); + result = float16ToRawShortBits(add(shortBitsToFloat16(result), shortBitsToFloat16(input1[i+2]))); + result = float16ToRawShortBits(add(shortBitsToFloat16(result), shortBitsToFloat16(input1[i+3]))); + } + return result; + } + + @Check(test="vectorAddReductionFloat16Partial") + public void checkResultAddReductionFloat16Partial() { + short expected = (short) 0; + for (int i = 0; i < LEN; i+=8) { + expected = floatToFloat16(float16ToFloat(expected) + float16ToFloat(input1[i])); + expected = floatToFloat16(float16ToFloat(expected) + float16ToFloat(input1[i+1])); + expected = floatToFloat16(float16ToFloat(expected) + float16ToFloat(input1[i+2])); + expected = floatToFloat16(float16ToFloat(expected) + float16ToFloat(input1[i+3])); + } + Verify.checkEQ(shortBitsToFloat16(expected), shortBitsToFloat16(vectorAddReductionFloat16Partial())); + } + + // Partial multiply reduction for floating point is disabled on AArch64. This test makes sure that code that performs such partial + // multiply reduction operation for FP16 runs without any failures/result mismatch. + @Test + @Warmup(500) + public short vectorMulReductionFloat16Partial() { + short result = floatToFloat16(1.0f); + for (int i = 0; i < LEN; i+=8) { + result = float16ToRawShortBits(multiply(shortBitsToFloat16(result), shortBitsToFloat16(input1[i]))); + result = float16ToRawShortBits(multiply(shortBitsToFloat16(result), shortBitsToFloat16(input1[i+1]))); + result = float16ToRawShortBits(multiply(shortBitsToFloat16(result), shortBitsToFloat16(input1[i+2]))); + result = float16ToRawShortBits(multiply(shortBitsToFloat16(result), shortBitsToFloat16(input1[i+3]))); + } + return result; + } + + @Check(test="vectorMulReductionFloat16Partial") + public void checkResultMulReductionFloat16Partial() { + short expected = floatToFloat16(1.0f); + for (int i = 0; i < LEN; i+=8) { + expected = floatToFloat16(float16ToFloat(expected) * float16ToFloat(input1[i])); + expected = floatToFloat16(float16ToFloat(expected) * float16ToFloat(input1[i+1])); + expected = floatToFloat16(float16ToFloat(expected) * float16ToFloat(input1[i+2])); + expected = floatToFloat16(float16ToFloat(expected) * float16ToFloat(input1[i+3])); + } + Verify.checkEQ(shortBitsToFloat16(expected), shortBitsToFloat16(vectorMulReductionFloat16Partial())); + } + + // This test case verifies that autovectorization does NOT take place when using Float16. + // Filed RFE: JDK-8375321 + @Test + @Warmup(50) + @IR(counts = {IRNode.ADD_REDUCTION_VHF, " =0 "}, + applyIfCPUFeature = {"sve", "true"}) + @IR(counts = {IRNode.ADD_REDUCTION_VHF, " =0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}) + public Float16 vectorAddReductionFloat16NotVectorized() { + Float16 result = Float16.valueOf(0.0f); + for (int i = 0; i < LEN; i++) { + result = add(result, input4[i]); + } + return result; + } + + @Check(test="vectorAddReductionFloat16NotVectorized") + public void checkResultAddReductionFloat16NotVectorized() { + Float16 expected = Float16.valueOf(0.0f); + for (int i = 0; i < LEN; ++i) { + expected = Float16.valueOf(expected.floatValue() + input4[i].floatValue()); + } + Verify.checkEQ(expected, vectorAddReductionFloat16NotVectorized()); + } + + @Test + @Warmup(50) + @IR(counts = {IRNode.MUL_REDUCTION_VHF, " =0 "}, + applyIfCPUFeatureAnd = {"fphp", "true", "asimdhp", "true"}, + applyIf = {"MaxVectorSize", "<=16"}) + public Float16 vectorMulReductionFloat16NotVectorized() { + Float16 result = Float16.valueOf(1.0f); + for (int i = 0; i < LEN; i++) { + result = multiply(result, input4[i]); + } + return result; + } + + @Check(test="vectorMulReductionFloat16NotVectorized") + public void checkResultMulReductionFloat16NotVectorized() { + Float16 expected = Float16.valueOf(1.0f); + for (int i = 0; i < LEN; ++i) { + expected = Float16.valueOf(expected.floatValue() * input4[i].floatValue()); + } + Verify.checkEQ(expected, vectorMulReductionFloat16NotVectorized()); + } + + @Test + @Warmup(500) + @IR(counts = {"reduce_addFHF_masked", " =0 "}, phase = {CompilePhase.FINAL_CODE}, + applyIfCPUFeature = {"sve", "true"}) + public Float16 vectorAddReductionFloat16PartialNotVectorized() { + Float16 result = Float16.valueOf(0.0f); + for (int i = 0; i < LEN; i += 8) { + result = add(result, input4[i]); + result = add(result, input4[i + 1]); + result = add(result, input4[i + 2]); + result = add(result, input4[i + 3]); + } + return result; + } + + @Check(test="vectorAddReductionFloat16PartialNotVectorized") + public void checkResultAddReductionFloat16PartialNotVectorized() { + Float16 expected = Float16.valueOf(0.0f); + for (int i = 0; i < LEN; i += 8) { + expected = Float16.valueOf(expected.floatValue() + input4[i].floatValue()); + expected = Float16.valueOf(expected.floatValue() + input4[i + 1].floatValue()); + expected = Float16.valueOf(expected.floatValue() + input4[i + 2].floatValue()); + expected = Float16.valueOf(expected.floatValue() + input4[i + 3].floatValue()); + } + Verify.checkEQ(expected, vectorAddReductionFloat16PartialNotVectorized()); + } + + @Test + @Warmup(500) + public Float16 vectorMulReductionFloat16PartialNotVectorized() { + Float16 result = Float16.valueOf(1.0f); + for (int i = 0; i < LEN; i += 8) { + result = multiply(result, input4[i]); + result = multiply(result, input4[i + 1]); + result = multiply(result, input4[i + 2]); + result = multiply(result, input4[i + 3]); + } + return result; + } + + @Check(test="vectorMulReductionFloat16PartialNotVectorized") + public void checkResultMulReductionFloat16PartialNotVectorized() { + Float16 expected = Float16.valueOf(1.0f); + for (int i = 0; i < LEN; i += 8) { + expected = Float16.valueOf(expected.floatValue() * input4[i].floatValue()); + expected = Float16.valueOf(expected.floatValue() * input4[i + 1].floatValue()); + expected = Float16.valueOf(expected.floatValue() * input4[i + 2].floatValue()); + expected = Float16.valueOf(expected.floatValue() * input4[i + 3].floatValue()); + } + Verify.checkEQ(expected, vectorMulReductionFloat16PartialNotVectorized()); + } + } diff --git a/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java b/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java index 92c0b58005f..daf18af528e 100644 --- a/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java +++ b/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java @@ -1,5 +1,6 @@ /* * Copyright (c) 2025, 2026, Oracle and/or its affiliates. All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -350,4 +351,22 @@ public class Float16OperationsBenchmark { } return distRes; } + + @Benchmark + public short reductionAddFP16() { + short result = (short) 0; + for (int i = 0; i < vectorDim; i++) { + result = float16ToRawShortBits(add(shortBitsToFloat16(result), shortBitsToFloat16(vector1[i]))); + } + return result; + } + + @Benchmark + public short reductionMulFP16() { + short result = floatToFloat16(1.0f); + for (int i = 0; i < vectorDim; i++) { + result = float16ToRawShortBits(multiply(shortBitsToFloat16(result), shortBitsToFloat16(vector1[i]))); + } + return result; + } } diff --git a/test/micro/org/openjdk/bench/vm/compiler/VectorReduction2.java b/test/micro/org/openjdk/bench/vm/compiler/VectorReduction2.java index 9241aca1dad..0d11705c8ec 100644 --- a/test/micro/org/openjdk/bench/vm/compiler/VectorReduction2.java +++ b/test/micro/org/openjdk/bench/vm/compiler/VectorReduction2.java @@ -1,5 +1,6 @@ /* * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -27,6 +28,7 @@ import org.openjdk.jmh.infra.*; import java.util.concurrent.TimeUnit; import java.util.Random; +import jdk.incubator.vector.Float16; /** * Note: there is a corresponding IR test: @@ -64,6 +66,9 @@ public abstract class VectorReduction2 { private double[] in1D; private double[] in2D; private double[] in3D; + private short[] in1F16; + private short[] in2F16; + private short[] in3F16; @Param("0") private int seed; @@ -96,6 +101,9 @@ public abstract class VectorReduction2 { in1D = new double[SIZE]; in2D = new double[SIZE]; in3D = new double[SIZE]; + in1F16 = new short[SIZE]; + in2F16 = new short[SIZE]; + in3F16 = new short[SIZE]; for (int i = 0; i < SIZE; i++) { in1B[i] = (byte)r.nextInt(); @@ -121,6 +129,9 @@ public abstract class VectorReduction2 { in1D[i] = r.nextDouble(); in2D[i] = r.nextDouble(); in3D[i] = r.nextDouble(); + in1F16[i] = Float.floatToFloat16(r.nextFloat()); + in2F16[i] = Float.floatToFloat16(r.nextFloat()); + in3F16[i] = Float.floatToFloat16(r.nextFloat()); } } @@ -1449,10 +1460,86 @@ public abstract class VectorReduction2 { bh.consume(acc); } - @Fork(value = 1, jvmArgs = {"-XX:+UseSuperWord"}) + // ---------float16***Simple ------------------------------------------------------------ + @Benchmark + public void float16AddSimple(Blackhole bh) { + short acc = (short)0; // neutral element + for (int i = 0; i < SIZE; i++) { + acc = Float16.float16ToRawShortBits( + Float16.add(Float16.shortBitsToFloat16(acc), Float16.shortBitsToFloat16(in1F16[i]))); + } + bh.consume(acc); + } + + @Benchmark + public void float16MulSimple(Blackhole bh) { + short acc = Float.floatToFloat16(1.0f); // neutral element + for (int i = 0; i < SIZE; i++) { + acc = Float16.float16ToRawShortBits( + Float16.multiply(Float16.shortBitsToFloat16(acc), Float16.shortBitsToFloat16(in1F16[i]))); + } + bh.consume(acc); + } + + // ---------float16***DotProduct ------------------------------------------------------------ + @Benchmark + public void float16AddDotProduct(Blackhole bh) { + short acc = (short)0; // neutral element + for (int i = 0; i < SIZE; i++) { + Float16 val = Float16.multiply(Float16.shortBitsToFloat16(in1F16[i]), + Float16.shortBitsToFloat16(in2F16[i])); + acc = Float16.float16ToRawShortBits( + Float16.add(Float16.shortBitsToFloat16(acc), val)); + } + bh.consume(acc); + } + + @Benchmark + public void float16MulDotProduct(Blackhole bh) { + short acc = Float.floatToFloat16(1.0f); // neutral element + for (int i = 0; i < SIZE; i++) { + Float16 val = Float16.multiply(Float16.shortBitsToFloat16(in1F16[i]), + Float16.shortBitsToFloat16(in2F16[i])); + acc = Float16.float16ToRawShortBits( + Float16.multiply(Float16.shortBitsToFloat16(acc), val)); + } + bh.consume(acc); + } + + // ---------float16***Big ------------------------------------------------------------ + @Benchmark + public void float16AddBig(Blackhole bh) { + short acc = (short)0; // neutral element + for (int i = 0; i < SIZE; i++) { + Float16 a = Float16.shortBitsToFloat16(in1F16[i]); + Float16 b = Float16.shortBitsToFloat16(in2F16[i]); + Float16 c = Float16.shortBitsToFloat16(in3F16[i]); + Float16 val = Float16.add(Float16.multiply(a, b), + Float16.add(Float16.multiply(a, c), Float16.multiply(b, c))); + acc = Float16.float16ToRawShortBits( + Float16.add(Float16.shortBitsToFloat16(acc), val)); + } + bh.consume(acc); + } + + @Benchmark + public void float16MulBig(Blackhole bh) { + short acc = Float.floatToFloat16(1.0f); // neutral element + for (int i = 0; i < SIZE; i++) { + Float16 a = Float16.shortBitsToFloat16(in1F16[i]); + Float16 b = Float16.shortBitsToFloat16(in2F16[i]); + Float16 c = Float16.shortBitsToFloat16(in3F16[i]); + Float16 val = Float16.add(Float16.multiply(a, b), + Float16.add(Float16.multiply(a, c), Float16.multiply(b, c))); + acc = Float16.float16ToRawShortBits( + Float16.multiply(Float16.shortBitsToFloat16(acc), val)); + } + bh.consume(acc); + } + + @Fork(value = 1, jvmArgs = {"--add-modules=jdk.incubator.vector", "-XX:+UseSuperWord"}) public static class WithSuperword extends VectorReduction2 {} - @Fork(value = 1, jvmArgs = {"-XX:-UseSuperWord"}) + @Fork(value = 1, jvmArgs = {"--add-modules=jdk.incubator.vector", "-XX:-UseSuperWord"}) public static class NoSuperword extends VectorReduction2 {} } -