diff --git a/src/hotspot/share/opto/superword.cpp b/src/hotspot/share/opto/superword.cpp
index 47735f6bbba..9e1d1e6ca58 100644
--- a/src/hotspot/share/opto/superword.cpp
+++ b/src/hotspot/share/opto/superword.cpp
@@ -2535,6 +2535,81 @@ VStatus VLoopBody::construct() {
   return VStatus::make_success();
 }
 
+// Returns true if node 'in' can be vectorized with "truncation" when given the (possibly subword) element type 'type',
+// i.e. if the upper bits of its int inputs do not contribute to the bits of the result that are kept. This is true for
+// most arithmetic operations, but false for operations such as leading/trailing zero count. False is always safe.
+static bool can_subword_truncate(Node* in, const Type* type) {
+  if (in->is_Load() || in->is_Store() || in->is_Convert() || in->is_Phi()) {
+    return true;
+  }
+
+  int opc = in->Opcode();
+
+  // For the 16-bit element types (short/char), the subword byte-reversal nodes can be truncated as well.
+  if (type == TypeInt::SHORT || type == TypeInt::CHAR) {
+    switch (opc) {
+    case Op_ReverseBytesS:
+    case Op_ReverseBytesUS:
+      return true;
+    }
+  }
+
+  // Can be truncated: the low bits of the result of these int operations depend only on the low bits of the inputs.
+  switch (opc) {
+  case Op_AddI:
+  case Op_SubI:
+  case Op_MulI:
+  case Op_AndI:
+  case Op_OrI:
+  case Op_XorI:
+    return true;
+  }
+
+#ifdef ASSERT
+  // Debug builds only: every path below returns false (the same answer as the product default) or trips the assert.
+  // Shifts have subword vectorized forms, but need the precise type of their input loads, so they are non-truncating.
+  if (VectorNode::is_shift_opcode(opc)) {
+    return false;
+  }
+
+  // Nodes whose type is already a vector or a vector mask, and reduction nodes, should not truncate.
+  if (type->isa_vect() != nullptr || type->isa_vectmask() != nullptr || in->is_Reduction()) {
+    return false;
+  }
+
+  // Cannot be truncated (listed explicitly so that the default case only fires for opcodes nobody has classified yet):
+  switch (opc) {
+  case Op_AbsI:
+  case Op_DivI:
+  case Op_MinI:
+  case Op_MaxI:
+  case Op_CMoveI:
+  case Op_Conv2B:
+  case Op_RotateRight:
+  case Op_RotateLeft:
+  case Op_PopCountI:
+  case Op_ReverseBytesI:
+  case Op_ReverseI:
+  case Op_CountLeadingZerosI:
+  case Op_CountTrailingZerosI:
+  case Op_IsInfiniteF:
+  case Op_IsInfiniteD:
+  case Op_ExtractS:
+  case Op_ExtractC:
+  case Op_ExtractB:
+    return false;
+  default:
+    // If this assert is hit, an unclassified opcode reached this function: determine if the node can be safely
+    // truncated, and then add it to the list of truncating nodes or the list of non-truncating ones just above.
+    // In product, we just return false, which is always correct.
+    assert(false, "Unexpected node in SuperWord truncation: %s", NodeClassNames[in->Opcode()]);
+  }
+#endif
+
+  // Default to disallowing vector truncation; in product builds this is the single exit for all non-truncating nodes.
+  return false;
+}
+
 void VLoopTypes::compute_vector_element_type() {
 #ifndef PRODUCT
   if (_vloop.is_trace_vector_element_type()) {
@@ -2589,18 +2664,19 @@ void VLoopTypes::compute_vector_element_type() {
           // be vectorized if the higher order bits info is imprecise.
           const Type* vt = vtn;
           int op = in->Opcode();
-          if (VectorNode::is_shift_opcode(op) || op == Op_AbsI || op == Op_ReverseBytesI) {
+          if (!can_subword_truncate(in, vt)) {
             Node* load = in->in(1);
-            if (load->is_Load() &&
+            // For certain operations such as shifts and abs(), use the size of the load if it exists
+            if ((VectorNode::is_shift_opcode(op) || op == Op_AbsI) && load->is_Load() &&
                 _vloop.in_bb(load) &&
                 (velt_type(load)->basic_type() == T_INT)) {
               // Only Load nodes distinguish signed (LoadS/LoadB) and unsigned
               // (LoadUS/LoadUB) values. Store nodes only have one version.
               vt = velt_type(load);
             } else if (op != Op_LShiftI) {
-              // Widen type to int to avoid the creation of vector nodes. 
Note + // Widen type to the node type to avoid the creation of vector nodes. Note // that left shifts work regardless of the signedness. - vt = TypeInt::INT; + vt = container_type(in); } } set_velt_type(in, vt); diff --git a/test/hotspot/jtreg/compiler/vectorization/TestSubwordTruncation.java b/test/hotspot/jtreg/compiler/vectorization/TestSubwordTruncation.java new file mode 100644 index 00000000000..f355a0bf05f --- /dev/null +++ b/test/hotspot/jtreg/compiler/vectorization/TestSubwordTruncation.java @@ -0,0 +1,384 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */
+
+package compiler.vectorization;
+
+import jdk.test.lib.Asserts;
+import compiler.lib.ir_framework.*;
+import compiler.lib.generators.*;
+
+/*
+ * @test
+ * @bug 8350177
+ * @summary Ensure that truncation of subword vectors produces correct results
+ * @library /test/lib /
+ * @run driver compiler.vectorization.TestSubwordTruncation
+ */
+
+public class TestSubwordTruncation {
+    private static final RestrictableGenerator<Integer> G = Generators.G.ints(); // ints, narrowed by the setup methods
+    private static final int SIZE = 10000;
+
+    @Setup
+    static Object[] setupShortArray() {
+        short[] arr = new short[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            arr[i] = G.next().shortValue();
+        }
+
+        return new Object[] { arr };
+    }
+
+    // Each setup method returns a single-element argument list holding an array of SIZE generated values.
+
+    @Setup
+    static Object[] setupByteArray() {
+        byte[] arr = new byte[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            arr[i] = G.next().byteValue();
+        }
+
+        return new Object[] { arr };
+    }
+
+    @Setup
+    static Object[] setupCharArray() {
+        char[] arr = new char[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            arr[i] = (char) G.next().shortValue();
+        }
+
+        return new Object[] { arr };
+    }
+
+    // Shorts: the @IR rule demands zero StoreVector nodes; each @Check recomputes the expected value element by element.
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupShortArray")
+    public Object[] testShortLeadingZeros(short[] in) {
+        short[] res = new short[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (short) Integer.numberOfLeadingZeros(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testShortLeadingZeros")
+    public void checkTestShortLeadingZeros(Object[] vals) {
+        short[] in = (short[]) vals[0];
+        short[] res = (short[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            short val = (short) Integer.numberOfLeadingZeros(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupShortArray")
+    public Object[] testShortTrailingZeros(short[] in) {
+        short[] res = new short[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (short) Integer.numberOfTrailingZeros(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testShortTrailingZeros")
+    public void checkTestShortTrailingZeros(Object[] vals) {
+        short[] in = (short[]) vals[0];
+        short[] res = (short[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            short val = (short) Integer.numberOfTrailingZeros(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupShortArray")
+    public Object[] testShortReverse(short[] in) {
+        short[] res = new short[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (short) Integer.reverse(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testShortReverse")
+    public void checkTestShortReverse(Object[] vals) {
+        short[] in = (short[]) vals[0];
+        short[] res = (short[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            short val = (short) Integer.reverse(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupShortArray")
+    public Object[] testShortBitCount(short[] in) {
+        short[] res = new short[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (short) Integer.bitCount(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testShortBitCount")
+    public void checkTestShortBitCount(Object[] vals) {
+        short[] in = (short[]) vals[0];
+        short[] res = (short[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            short val = (short) Integer.bitCount(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    // Chars: same four operations on char[]; unlike short/byte, a char operand is zero-extended when promoted to int.
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupCharArray")
+    public Object[] testCharLeadingZeros(char[] in) {
+        char[] res = new char[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (char) Integer.numberOfLeadingZeros(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testCharLeadingZeros")
+    public void checkTestCharLeadingZeros(Object[] vals) {
+        char[] in = (char[]) vals[0];
+        char[] res = (char[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            char val = (char) Integer.numberOfLeadingZeros(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupCharArray")
+    public Object[] testCharTrailingZeros(char[] in) {
+        char[] res = new char[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (char) Integer.numberOfTrailingZeros(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testCharTrailingZeros")
+    public void checkTestCharTrailingZeros(Object[] vals) {
+        char[] in = (char[]) vals[0];
+        char[] res = (char[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            char val = (char) Integer.numberOfTrailingZeros(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupCharArray")
+    public Object[] testCharReverse(char[] in) {
+        char[] res = new char[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (char) Integer.reverse(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testCharReverse")
+    public void checkTestCharReverse(Object[] vals) {
+        char[] in = (char[]) vals[0];
+        char[] res = (char[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            char val = (char) Integer.reverse(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupCharArray")
+    public Object[] testCharBitCount(char[] in) {
+        char[] res = new char[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (char) Integer.bitCount(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testCharBitCount")
+    public void checkTestCharBitCount(Object[] vals) {
+        char[] in = (char[]) vals[0];
+        char[] res = (char[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            char val = (char) Integer.bitCount(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    // Bytes: same four operations on byte[]; a byte operand is sign-extended when promoted to int.
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupByteArray")
+    public Object[] testByteLeadingZeros(byte[] in) {
+        byte[] res = new byte[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (byte) Integer.numberOfLeadingZeros(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testByteLeadingZeros")
+    public void checkTestByteLeadingZeros(Object[] vals) {
+        byte[] in = (byte[]) vals[0];
+        byte[] res = (byte[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            byte val = (byte) Integer.numberOfLeadingZeros(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupByteArray")
+    public Object[] testByteTrailingZeros(byte[] in) {
+        byte[] res = new byte[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (byte) Integer.numberOfTrailingZeros(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testByteTrailingZeros")
+    public void checkTestByteTrailingZeros(Object[] vals) {
+        byte[] in = (byte[]) vals[0];
+        byte[] res = (byte[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            byte val = (byte) Integer.numberOfTrailingZeros(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupByteArray")
+    public Object[] testByteReverse(byte[] in) {
+        byte[] res = new byte[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (byte) Integer.reverse(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testByteReverse")
+    public void checkTestByteReverse(Object[] vals) {
+        byte[] in = (byte[]) vals[0];
+        byte[] res = (byte[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            byte val = (byte) Integer.reverse(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = { IRNode.STORE_VECTOR, "=0" })
+    @Arguments(setup = "setupByteArray")
+    public Object[] testByteBitCount(byte[] in) {
+        byte[] res = new byte[SIZE];
+        for (int i = 0; i < SIZE; i++) {
+            res[i] = (byte) Integer.bitCount(in[i]);
+        }
+
+        return new Object[] { in, res };
+    }
+
+    @Check(test = "testByteBitCount")
+    public void checkTestByteBitCount(Object[] vals) {
+        byte[] in = (byte[]) vals[0];
+        byte[] res = (byte[]) vals[1];
+
+        for (int i = 0; i < SIZE; i++) {
+            byte val = (byte) Integer.bitCount(in[i]);
+            if (res[i] != val) {
+                throw new IllegalStateException("Expected " + val + " but got " + res[i] + " for " + in[i]);
+            }
+        }
+    }
+
+    // Entry point named by the "@run driver" tag above; hands control to the IR test framework.
+
+    public static void main(String[] args) {
+        TestFramework.run();
+    }
+}
+