8350468: x86: Improve implementation of vectorized numberOfLeadingZeros for int and long

Co-authored-by: Raffaello Giulietti <rgiulietti@openjdk.org>
Reviewed-by: sviswanathan, qamai, vlivanov
This commit is contained in:
Jasmine Karthikeyan 2025-11-10 06:16:02 +00:00
parent 4e4cced710
commit f77a5117db
3 changed files with 208 additions and 65 deletions

View File

@ -6119,77 +6119,64 @@ void C2_MacroAssembler::vector_count_leading_zeros_short_avx(XMMRegister dst, XM
void C2_MacroAssembler::vector_count_leading_zeros_int_avx(XMMRegister dst, XMMRegister src, XMMRegister xtmp1,
XMMRegister xtmp2, XMMRegister xtmp3, int vec_enc) {
// Since IEEE 754 floating point format represents mantissa in 1.0 format
// hence biased exponent can be used to compute leading zero count as per
// following formula:-
// LZCNT = 31 - (biased_exp - 127)
// Special handling has been introduced for Zero, Max_Int and -ve source values.
// Broadcast 0xFF
vpcmpeqd(xtmp1, xtmp1, xtmp1, vec_enc);
vpsrld(xtmp1, xtmp1, 24, vec_enc);
// By converting the integer to a float, we can obtain the number of leading zeros based on the exponent of the float.
// As the float exponent contains a bias of 127 for nonzero values, the bias must be removed before interpreting the
// exponent as the leading zero count.
// Remove the bit to the right of the highest set bit ensuring that the conversion to float cannot round up to a higher
// power of 2, which has a higher exponent than the input. This transformation is valid as only the highest set bit
// contributes to the leading number of zeros.
vpsrld(xtmp2, src, 1, vec_enc);
vpandn(xtmp3, xtmp2, src, vec_enc);
vpsrld(dst, src, 1, vec_enc);
vpandn(dst, dst, src, vec_enc);
// Extract biased exponent.
vcvtdq2ps(dst, xtmp3, vec_enc);
vcvtdq2ps(dst, dst, vec_enc);
// By comparing the register to itself, all the bits in the destination are set.
vpcmpeqd(xtmp1, xtmp1, xtmp1, vec_enc);
// Move the biased exponent to the low end of the lane and mask with 0xFF to discard the sign bit.
vpsrld(xtmp2, xtmp1, 24, vec_enc);
vpsrld(dst, dst, 23, vec_enc);
vpand(dst, dst, xtmp1, vec_enc);
vpand(dst, xtmp2, dst, vec_enc);
// Broadcast 127.
vpsrld(xtmp1, xtmp1, 1, vec_enc);
// Exponent = biased_exp - 127
vpsubd(dst, dst, xtmp1, vec_enc);
// Subtract 127 from the exponent, which removes the bias from the exponent.
vpsrld(xtmp2, xtmp1, 25, vec_enc);
vpsubd(dst, dst, xtmp2, vec_enc);
// Exponent_plus_one = Exponent + 1
vpsrld(xtmp3, xtmp1, 6, vec_enc);
vpaddd(dst, dst, xtmp3, vec_enc);
vpsrld(xtmp2, xtmp1, 27, vec_enc);
// Replace -ve exponent with zero, exponent is -ve when src
// lane contains a zero value.
vpxor(xtmp2, xtmp2, xtmp2, vec_enc);
vblendvps(dst, dst, xtmp2, dst, vec_enc);
// If the original value is 0 the exponent would not have bias, so the subtraction creates a negative number. If this
// is found in any of the lanes, replace the lane with -1 from xtmp1.
vblendvps(dst, dst, xtmp1, dst, vec_enc, true, xtmp3);
// Rematerialize broadcast 32.
vpslld(xtmp1, xtmp3, 5, vec_enc);
// Exponent is 32 if corresponding source lane contains max_int value.
vpcmpeqd(xtmp2, dst, xtmp1, vec_enc);
// LZCNT = 32 - exponent_plus_one
vpsubd(dst, xtmp1, dst, vec_enc);
// If the original value is negative, replace the lane with 31.
vblendvps(dst, dst, xtmp2, src, vec_enc, true, xtmp3);
// Replace LZCNT with a value 1 if corresponding source lane
// contains max_int value.
vpblendvb(dst, dst, xtmp3, xtmp2, vec_enc);
// Replace biased_exp with 0 if source lane value is less than zero.
vpxor(xtmp2, xtmp2, xtmp2, vec_enc);
vblendvps(dst, dst, xtmp2, src, vec_enc);
// Subtract the exponent from 31, giving the final result. For 0, the result is 32 as the exponent was replaced with -1,
// and for negative numbers the result is 0 as the exponent was replaced with 31.
vpsubd(dst, xtmp2, dst, vec_enc);
}
void C2_MacroAssembler::vector_count_leading_zeros_long_avx(XMMRegister dst, XMMRegister src, XMMRegister xtmp1,
XMMRegister xtmp2, XMMRegister xtmp3, Register rtmp, int vec_enc) {
vector_count_leading_zeros_short_avx(dst, src, xtmp1, xtmp2, xtmp3, rtmp, vec_enc);
// Add zero counts of lower word and upper word of a double word if
// upper word holds a zero value.
vpsrld(xtmp3, src, 16, vec_enc);
// xtmp1 is set to all zeros by vector_count_leading_zeros_byte_avx.
vpcmpeqd(xtmp3, xtmp1, xtmp3, vec_enc);
vpslld(xtmp2, dst, 16, vec_enc);
vpaddd(xtmp2, xtmp2, dst, vec_enc);
vpblendvb(dst, dst, xtmp2, xtmp3, vec_enc);
vpsrld(dst, dst, 16, vec_enc);
// Add zero counts of lower doubleword and upper doubleword of a
// quadword if upper doubleword holds a zero value.
vpsrlq(xtmp3, src, 32, vec_enc);
vpcmpeqq(xtmp3, xtmp1, xtmp3, vec_enc);
vpsllq(xtmp2, dst, 32, vec_enc);
vpaddq(xtmp2, xtmp2, dst, vec_enc);
vpblendvb(dst, dst, xtmp2, xtmp3, vec_enc);
vpsrlq(dst, dst, 32, vec_enc);
// Find the leading zeros of the top and bottom halves of the long individually.
vector_count_leading_zeros_int_avx(dst, src, xtmp1, xtmp2, xtmp3, vec_enc);
// Move the top half result to the bottom half of xtmp1, setting the top half to 0.
vpsrlq(xtmp1, dst, 32, vec_enc);
// By moving the top half result to the right by 6 bits, if the top half was empty (i.e. 32 is returned) the result bit will
// be in the most significant position of the bottom half.
vpsrlq(xtmp2, dst, 6, vec_enc);
// In the bottom half, add the top half and bottom half results.
vpaddq(dst, xtmp1, dst, vec_enc);
// For the bottom half, choose between the values using the most significant bit of xtmp2.
// If the MSB is set, then bottom+top in dst is the resulting value. If the top half is less than 32 xtmp1 is chosen,
// which contains only the top half result.
// In the top half the MSB is always zero, so the value in xtmp1 is always chosen. This value is always 0, which clears
// the lane as required.
vblendvps(dst, xtmp1, dst, xtmp2, vec_enc, true, xtmp3);
}
void C2_MacroAssembler::vector_count_leading_zeros_avx(BasicType bt, XMMRegister dst, XMMRegister src,

View File

@ -46,7 +46,8 @@ import jdk.test.lib.Asserts;
import jdk.test.lib.Utils;
public class TestNumberOfContinuousZeros {
private static final int[] SPECIAL = { 0x01FFFFFF, 0x03FFFFFE, 0x07FFFFFC, 0x0FFFFFF8, 0x1FFFFFF0, 0x3FFFFFE0, 0xFFFFFFFF };
private static final int[] SPECIAL_INT = { 0, 0x01FFFFFF, 0x03FFFFFE, 0x07FFFFFC, 0x0FFFFFF8, 0x1FFFFFF0, 0x3FFFFFE0, 0xFFFFFFFF };
private static final long[] SPECIAL_LONG = { 0, 0xFF, 0xFFFF, 0x01FFFFFF, 0x03FFFFFE, 0x07FFFFFC, 0x0FFFFFF8, 0x1FFFFFF0, 0x3FFFFFE0, 0xFFFFFFFF, 0xFFFFFFFFFFFFFFFFL, 0x7FFFFFFFFFFFFFFFL };
private long[] inputLong;
private int[] outputLong;
private int[] inputInt;
@ -134,7 +135,18 @@ public class TestNumberOfContinuousZeros {
int[] res = new int[LEN];
for (int i = 0; i < LEN; i++) {
res[i] = SPECIAL[i % SPECIAL.length];
res[i] = SPECIAL_INT[i % SPECIAL_INT.length];
}
return new Object[] { res };
}
@Setup
static Object[] setupSpecialLongArray() {
long[] res = new long[LEN];
for (int i = 0; i < LEN; i++) {
res[i] = SPECIAL_LONG[i % SPECIAL_LONG.length];
}
return new Object[] { res };
@ -167,23 +179,51 @@ public class TestNumberOfContinuousZeros {
}
}
private static final VectorSpecies<Integer> SPECIES = IntVector.SPECIES_PREFERRED;
@Test
@IR(counts = {IRNode.COUNT_LEADING_ZEROS_VL, "> 0"})
@Arguments(setup = "setupSpecialLongArray")
public Object[] testSpecialLongLeadingZeros(long[] longs) {
int[] res = new int[LEN];
for (int i = 0; i < LEN; ++i) {
res[i] = Long.numberOfLeadingZeros(longs[i]);
}
return new Object[] { longs, res };
}
@Check(test = "testSpecialLongLeadingZeros")
public void checkSpecialLongLeadingZeros(Object[] vals) {
long[] in = (long[]) vals[0];
int[] out = (int[]) vals[1];
for (int i = 0; i < LEN; ++i) {
int value = Long.numberOfLeadingZeros(in[i]);
if (out[i] != value) {
throw new IllegalStateException("Expected lzcnt(" + in[i] + ") to be " + value + " but got " + out[i]);
}
}
}
private static final VectorSpecies<Integer> SPECIES_INT = IntVector.SPECIES_PREFERRED;
private static final VectorSpecies<Long> SPECIES_LONG = LongVector.SPECIES_PREFERRED;
@Test
@IR(counts = {IRNode.COUNT_LEADING_ZEROS_VI, "> 0"})
@Arguments(setup = "setupSpecialIntArray")
public Object[] checkSpecialIntLeadingZerosVector(int[] ints) {
public Object[] testIntLeadingZerosVector(int[] ints) {
int[] res = new int[LEN];
for (int i = 0; i < ints.length; i += SPECIES.length()) {
IntVector av = IntVector.fromArray(SPECIES, ints, i);
for (int i = 0; i < ints.length; i += SPECIES_INT.length()) {
IntVector av = IntVector.fromArray(SPECIES_INT, ints, i);
av.lanewise(VectorOperators.LEADING_ZEROS_COUNT).intoArray(res, i);
}
return new Object[] { ints, res };
}
@Check(test = "checkSpecialIntLeadingZerosVector")
@Check(test = "testIntLeadingZerosVector")
public void checkSpecialIntLeadingZerosVector(Object[] vals) {
int[] ints = (int[]) vals[0];
int[] res = (int[]) vals[1];
@ -192,8 +232,43 @@ public class TestNumberOfContinuousZeros {
int[] check = new int[LEN];
for (int i = 0; i < ints.length; i += SPECIES.length()) {
IntVector av = IntVector.fromArray(SPECIES, ints, i);
for (int i = 0; i < ints.length; i += SPECIES_INT.length()) {
IntVector av = IntVector.fromArray(SPECIES_INT, ints, i);
av.lanewise(VectorOperators.LEADING_ZEROS_COUNT).intoArray(check, i);
}
for (int i = 0; i < LEN; i++) {
if (res[i] != check[i]) {
throw new IllegalStateException("Expected " + check[i] + " but got " + res[i]);
}
}
}
@Test
@IR(counts = {IRNode.COUNT_LEADING_ZEROS_VL, "> 0"})
@Arguments(setup = "setupSpecialLongArray")
public Object[] testLongLeadingZerosVector(long[] longs) {
long[] res = new long[LEN];
for (int i = 0; i < longs.length; i += SPECIES_LONG.length()) {
LongVector av = LongVector.fromArray(SPECIES_LONG, longs, i);
av.lanewise(VectorOperators.LEADING_ZEROS_COUNT).intoArray(res, i);
}
return new Object[] { longs, res };
}
@Check(test = "testLongLeadingZerosVector")
public void checkSpecialLongLeadingZerosVector(Object[] vals) {
long[] longs = (long[]) vals[0];
long[] res = (long[]) vals[1];
// Verification
long[] check = new long[LEN];
for (int i = 0; i < longs.length; i += SPECIES_LONG.length()) {
LongVector av = LongVector.fromArray(SPECIES_LONG, longs, i);
av.lanewise(VectorOperators.LEADING_ZEROS_COUNT).intoArray(check, i);
}

View File

@ -0,0 +1,81 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.vm.compiler;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
import java.util.Random;
import java.util.concurrent.TimeUnit;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Measurement(iterations = 5, time = 1000, timeUnit = TimeUnit.MILLISECONDS)
@Warmup(iterations = 5, time = 1000, timeUnit = TimeUnit.MILLISECONDS)
@Fork(3)
public class LeadingZeros {
private static final int SIZE = 512;
@Benchmark
public void testInt(Blackhole blackhole, BenchState state) {
for (int i = 0; i < SIZE; i++) {
state.result[i] = Integer.numberOfLeadingZeros(state.ints[i]);
}
blackhole.consume(state.result);
}
@Benchmark
public void testLong(Blackhole blackhole, BenchState state) {
for (int i = 0; i < SIZE; i++) {
state.result[i] = Long.numberOfLeadingZeros(state.longs[i]);
}
blackhole.consume(state.result);
}
@State(Scope.Benchmark)
public static class BenchState {
private final int[] ints = new int[SIZE];
private final long[] longs = new long[SIZE];
private final int[] result = new int[SIZE];
private Random random;
public BenchState() {
}
@Setup
public void setup() {
this.random = new Random(1000);
for (int i = 0; i < SIZE; i++) {
ints[i] = this.random.nextInt();
longs[i] = this.random.nextLong();
}
}
}
}