diff --git a/src/java.base/share/classes/java/lang/foreign/MemoryLayout.java b/src/java.base/share/classes/java/lang/foreign/MemoryLayout.java index eaac5cb936d..44a620ea3c9 100644 --- a/src/java.base/share/classes/java/lang/foreign/MemoryLayout.java +++ b/src/java.base/share/classes/java/lang/foreign/MemoryLayout.java @@ -704,12 +704,12 @@ public sealed interface MemoryLayout permits SequenceLayout, GroupLayout, Paddin * @param elementLayout the sequence element layout. * @return the new sequence layout with the given element layout and size. * @throws IllegalArgumentException if {@code elementCount } is negative. - * @throws IllegalArgumentException if {@code elementLayout.bitAlignment() > elementLayout.bitSize()}. + * @throws IllegalArgumentException if {@code elementLayout.bitSize() % elementLayout.bitAlignment() != 0}. */ static SequenceLayout sequenceLayout(long elementCount, MemoryLayout elementLayout) { MemoryLayoutUtil.requireNonNegative(elementCount); Objects.requireNonNull(elementLayout); - Utils.checkElementAlignment(elementLayout, "Element layout alignment greater than its size"); + Utils.checkElementAlignment(elementLayout, "Element layout size is not multiple of alignment"); return wrapOverflow(() -> SequenceLayoutImpl.of(elementCount, elementLayout)); } @@ -725,7 +725,7 @@ public sealed interface MemoryLayout permits SequenceLayout, GroupLayout, Paddin * * @param elementLayout the sequence element layout. * @return a new sequence layout with the given element layout and maximum element count. - * @throws IllegalArgumentException if {@code elementLayout.bitAlignment() > elementLayout.bitSize()}. + * @throws IllegalArgumentException if {@code elementLayout.bitSize() % elementLayout.bitAlignment() != 0}. */ static SequenceLayout sequenceLayout(MemoryLayout elementLayout) { Objects.requireNonNull(elementLayout); diff --git a/src/java.base/share/classes/java/lang/foreign/MemorySegment.java b/src/java.base/share/classes/java/lang/foreign/MemorySegment.java index 1580787431f..becb3c05f0e 100644 --- a/src/java.base/share/classes/java/lang/foreign/MemorySegment.java +++ b/src/java.base/share/classes/java/lang/foreign/MemorySegment.java @@ -459,10 +459,11 @@ public sealed interface MemorySegment permits AbstractMemorySegmentImpl { * * @param elementLayout the layout to be used for splitting. * @return the element spliterator for this segment - * @throws IllegalArgumentException if the {@code elementLayout} size is zero, or the segment size modulo the - * {@code elementLayout} size is greater than zero, if this segment is - * incompatible with the alignment constraint in the provided layout, - * or if the {@code elementLayout} alignment is greater than its size. + * @throws IllegalArgumentException if {@code elementLayout.byteSize() == 0}. + * @throws IllegalArgumentException if {@code byteSize() % elementLayout.byteSize() != 0}. + * @throws IllegalArgumentException if {@code elementLayout.bitSize() % elementLayout.bitAlignment() != 0}. + * @throws IllegalArgumentException if this segment is incompatible + * with the alignment constraint in the provided layout. */ Spliterator spliterator(MemoryLayout elementLayout); @@ -475,10 +476,11 @@ public sealed interface MemorySegment permits AbstractMemorySegmentImpl { * * @param elementLayout the layout to be used for splitting. * @return a sequential {@code Stream} over disjoint slices in this segment. - * @throws IllegalArgumentException if the {@code elementLayout} size is zero, or the segment size modulo the - * {@code elementLayout} size is greater than zero, if this segment is - * incompatible with the alignment constraint in the provided layout, - * or if the {@code elementLayout} alignment is greater than its size. + * @throws IllegalArgumentException if {@code elementLayout.byteSize() == 0}. + * @throws IllegalArgumentException if {@code byteSize() % elementLayout.byteSize() != 0}. + * @throws IllegalArgumentException if {@code elementLayout.bitSize() % elementLayout.bitAlignment() != 0}. + * @throws IllegalArgumentException if this segment is incompatible + * with the alignment constraint in the provided layout. */ Stream elements(MemoryLayout elementLayout); diff --git a/src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java b/src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java index 021aeed43c9..9a316cd5825 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java +++ b/src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java @@ -165,7 +165,7 @@ public abstract sealed class AbstractMemorySegmentImpl if (elementLayout.byteSize() == 0) { throw new IllegalArgumentException("Element layout size cannot be zero"); } - Utils.checkElementAlignment(elementLayout, "Element layout alignment greater than its size"); + Utils.checkElementAlignment(elementLayout, "Element layout size is not multiple of alignment"); if (!isAlignedForElement(0, elementLayout)) { throw new IllegalArgumentException("Incompatible alignment constraints"); } diff --git a/src/java.base/share/classes/jdk/internal/foreign/Utils.java b/src/java.base/share/classes/jdk/internal/foreign/Utils.java index e6eda5130ec..6638a095265 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/Utils.java +++ b/src/java.base/share/classes/jdk/internal/foreign/Utils.java @@ -174,12 +174,22 @@ public final class Utils { } @ForceInline - public static void checkElementAlignment(MemoryLayout layout, String msg) { + public static void checkElementAlignment(ValueLayout layout, String msg) { + // Fast-path: if both size and alignment are powers of two, we can just + // check if one is greater than the other. + assert isPowerOfTwo(layout.bitSize()); if (layout.byteAlignment() > layout.byteSize()) { throw new IllegalArgumentException(msg); } } + @ForceInline + public static void checkElementAlignment(MemoryLayout layout, String msg) { + if (layout.byteSize() % layout.byteAlignment() != 0) { + throw new IllegalArgumentException(msg); + } + } + public static long pointeeByteSize(AddressLayout addressLayout) { return addressLayout.targetLayout() .map(MemoryLayout::byteSize) @@ -245,4 +255,8 @@ public final class Utils { public static int byteWidthOfPrimitive(Class primitive) { return Wrapper.forPrimitiveType(primitive).bitWidth() / 8; } + + public static boolean isPowerOfTwo(long value) { + return (value & (value - 1)) == 0L; + } } diff --git a/src/java.base/share/classes/jdk/internal/foreign/layout/AbstractLayout.java b/src/java.base/share/classes/jdk/internal/foreign/layout/AbstractLayout.java index ee109546d63..32c5edb555b 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/layout/AbstractLayout.java +++ b/src/java.base/share/classes/jdk/internal/foreign/layout/AbstractLayout.java @@ -25,6 +25,8 @@ */ package jdk.internal.foreign.layout; +import jdk.internal.foreign.Utils; + import java.lang.foreign.GroupLayout; import java.lang.foreign.MemoryLayout; import java.lang.foreign.SequenceLayout; @@ -138,8 +140,8 @@ public abstract sealed class AbstractLayout & Memory } private static long requirePowerOfTwoAndGreaterOrEqualToEight(long value) { - if (((value & (value - 1)) != 0L) || // value must be a power of two - (value < 8)) { // value must be greater or equal to 8 + if (!Utils.isPowerOfTwo(value) || // value must be a power of two + value < 8) { // value must be greater or equal to 8 throw new IllegalArgumentException("Invalid alignment: " + value); } return value; diff --git a/test/jdk/java/foreign/TestLayouts.java b/test/jdk/java/foreign/TestLayouts.java index 0e75baeaa45..0a32479cd0d 100644 --- a/test/jdk/java/foreign/TestLayouts.java +++ b/test/jdk/java/foreign/TestLayouts.java @@ -31,6 +31,7 @@ import java.lang.foreign.*; import java.lang.invoke.VarHandle; import java.nio.ByteOrder; +import java.util.ArrayList; import java.util.List; import java.util.function.LongFunction; import java.util.stream.Stream; @@ -292,11 +293,46 @@ public class TestLayouts { } @Test(dataProvider="layoutsAndAlignments", expectedExceptions = IllegalArgumentException.class) - public void testBadSequence(MemoryLayout layout, long bitAlign) { + public void testBadSequenceElementAlignmentTooBig(MemoryLayout layout, long bitAlign) { layout = layout.withBitAlignment(layout.bitSize() * 2); // hyper-align MemoryLayout.sequenceLayout(layout); } + @Test(dataProvider="layoutsAndAlignments") + public void testBadSequenceElementSizeNotMultipleOfAlignment(MemoryLayout layout, long bitAlign) { + boolean shouldFail = layout.byteSize() % layout.byteAlignment() != 0; + try { + MemoryLayout.sequenceLayout(layout); + assertFalse(shouldFail); + } catch (IllegalArgumentException ex) { + assertTrue(shouldFail); + } + } + + @Test(dataProvider="layoutsAndAlignments") + public void testBadSpliteratorElementSizeNotMultipleOfAlignment(MemoryLayout layout, long bitAlign) { + boolean shouldFail = layout.byteSize() % layout.byteAlignment() != 0; + try (Arena arena = Arena.ofConfined()) { + MemorySegment segment = arena.allocate(layout); + segment.spliterator(layout); + assertFalse(shouldFail); + } catch (IllegalArgumentException ex) { + assertTrue(shouldFail); + } + } + + @Test(dataProvider="layoutsAndAlignments") + public void testBadElementsElementSizeNotMultipleOfAlignment(MemoryLayout layout, long bitAlign) { + boolean shouldFail = layout.byteSize() % layout.byteAlignment() != 0; + try (Arena arena = Arena.ofConfined()) { + MemorySegment segment = arena.allocate(layout); + segment.elements(layout); + assertFalse(shouldFail); + } catch (IllegalArgumentException ex) { + assertTrue(shouldFail); + } + } + @Test(dataProvider="layoutsAndAlignments", expectedExceptions = IllegalArgumentException.class) public void testBadStruct(MemoryLayout layout, long bitAlign) { layout = layout.withBitAlignment(layout.bitSize() * 2); // hyper-align @@ -392,25 +428,32 @@ public class TestLayouts { @DataProvider(name = "layoutsAndAlignments") public Object[][] layoutsAndAlignments() { - Object[][] layoutsAndAlignments = new Object[basicLayouts.length * 4][]; + List layoutsAndAlignments = new ArrayList<>(); int i = 0; //add basic layouts for (MemoryLayout l : basicLayouts) { - layoutsAndAlignments[i++] = new Object[] { l, l.bitAlignment() }; + layoutsAndAlignments.add(new Object[] { l, l.bitAlignment() }); } //add basic layouts wrapped in a sequence with given size for (MemoryLayout l : basicLayouts) { - layoutsAndAlignments[i++] = new Object[] { MemoryLayout.sequenceLayout(4, l), l.bitAlignment() }; + layoutsAndAlignments.add(new Object[] { MemoryLayout.sequenceLayout(4, l), l.bitAlignment() }); } //add basic layouts wrapped in a struct - for (MemoryLayout l : basicLayouts) { - layoutsAndAlignments[i++] = new Object[] { MemoryLayout.structLayout(l), l.bitAlignment() }; + for (MemoryLayout l1 : basicLayouts) { + for (MemoryLayout l2 : basicLayouts) { + if (l1.byteSize() % l2.byteAlignment() != 0) continue; // second element is not aligned, skip + long align = Math.max(l1.bitAlignment(), l2.bitAlignment()); + layoutsAndAlignments.add(new Object[]{MemoryLayout.structLayout(l1, l2), align}); + } } //add basic layouts wrapped in a union - for (MemoryLayout l : basicLayouts) { - layoutsAndAlignments[i++] = new Object[] { MemoryLayout.unionLayout(l), l.bitAlignment() }; + for (MemoryLayout l1 : basicLayouts) { + for (MemoryLayout l2 : basicLayouts) { + long align = Math.max(l1.bitAlignment(), l2.bitAlignment()); + layoutsAndAlignments.add(new Object[]{MemoryLayout.unionLayout(l1, l2), align}); + } } - return layoutsAndAlignments; + return layoutsAndAlignments.toArray(Object[][]::new); } @DataProvider(name = "groupLayouts")