mirror of
https://github.com/openjdk/jdk.git
synced 2026-01-28 12:09:14 +00:00
8355719: Reduce memory consumption of BigInteger.pow()
Reviewed-by: rgiulietti
This commit is contained in:
parent
601f05e06d
commit
1c5eb370b7
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 1996, 2024, Oracle and/or its affiliates. All rights reserved.
|
||||
* Copyright (c) 1996, 2025, Oracle and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
@ -1246,6 +1246,16 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
|
||||
return new BigInteger(val);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a BigInteger with magnitude specified by the long,
|
||||
* which may not be zero, and the signum specified by the int.
|
||||
*/
|
||||
private BigInteger(long mag, int signum) {
|
||||
assert mag != 0 && signum != 0;
|
||||
this.signum = signum;
|
||||
this.mag = toMagArray(mag);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a BigInteger with the specified value, which may not be zero.
|
||||
*/
|
||||
@ -1256,16 +1266,14 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
|
||||
} else {
|
||||
signum = 1;
|
||||
}
|
||||
mag = toMagArray(val);
|
||||
}
|
||||
|
||||
int highWord = (int)(val >>> 32);
|
||||
if (highWord == 0) {
|
||||
mag = new int[1];
|
||||
mag[0] = (int)val;
|
||||
} else {
|
||||
mag = new int[2];
|
||||
mag[0] = highWord;
|
||||
mag[1] = (int)val;
|
||||
}
|
||||
private static int[] toMagArray(long mag) {
|
||||
int highWord = (int) (mag >>> 32);
|
||||
return highWord == 0
|
||||
? new int[] { (int) mag }
|
||||
: new int[] { highWord, (int) mag };
|
||||
}
|
||||
|
||||
/**
|
||||
@ -2589,116 +2597,101 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
|
||||
if (exponent < 0) {
|
||||
throw new ArithmeticException("Negative exponent");
|
||||
}
|
||||
if (signum == 0) {
|
||||
return (exponent == 0 ? ONE : this);
|
||||
}
|
||||
if (exponent == 0 || this.equals(ONE))
|
||||
return ONE;
|
||||
|
||||
BigInteger partToSquare = this.abs();
|
||||
if (signum == 0 || exponent == 1)
|
||||
return this;
|
||||
|
||||
BigInteger base = this.abs();
|
||||
final boolean negative = signum < 0 && (exponent & 1) == 1;
|
||||
|
||||
// Factor out powers of two from the base, as the exponentiation of
|
||||
// these can be done by left shifts only.
|
||||
// The remaining part can then be exponentiated faster. The
|
||||
// powers of two will be multiplied back at the end.
|
||||
int powersOfTwo = partToSquare.getLowestSetBit();
|
||||
long bitsToShiftLong = (long)powersOfTwo * exponent;
|
||||
if (bitsToShiftLong > Integer.MAX_VALUE) {
|
||||
final int powersOfTwo = base.getLowestSetBit();
|
||||
final long bitsToShiftLong = (long) powersOfTwo * exponent;
|
||||
final int bitsToShift = (int) bitsToShiftLong;
|
||||
if (bitsToShift != bitsToShiftLong) {
|
||||
reportOverflow();
|
||||
}
|
||||
int bitsToShift = (int)bitsToShiftLong;
|
||||
|
||||
int remainingBits;
|
||||
|
||||
// Factor the powers of two out quickly by shifting right, if needed.
|
||||
if (powersOfTwo > 0) {
|
||||
partToSquare = partToSquare.shiftRight(powersOfTwo);
|
||||
remainingBits = partToSquare.bitLength();
|
||||
if (remainingBits == 1) { // Nothing left but +/- 1?
|
||||
if (signum < 0 && (exponent&1) == 1) {
|
||||
return NEGATIVE_ONE.shiftLeft(bitsToShift);
|
||||
} else {
|
||||
return ONE.shiftLeft(bitsToShift);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
remainingBits = partToSquare.bitLength();
|
||||
if (remainingBits == 1) { // Nothing left but +/- 1?
|
||||
if (signum < 0 && (exponent&1) == 1) {
|
||||
return NEGATIVE_ONE;
|
||||
} else {
|
||||
return ONE;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Factor the powers of two out quickly by shifting right.
|
||||
base = base.shiftRight(powersOfTwo);
|
||||
final int remainingBits = base.bitLength();
|
||||
if (remainingBits == 1) // Nothing left but +/- 1?
|
||||
return (negative ? NEGATIVE_ONE : ONE).shiftLeft(bitsToShift);
|
||||
|
||||
// This is a quick way to approximate the size of the result,
|
||||
// similar to doing log2[n] * exponent. This will give an upper bound
|
||||
// of how big the result can be, and which algorithm to use.
|
||||
long scaleFactor = (long)remainingBits * exponent;
|
||||
final long scaleFactor = (long) remainingBits * exponent;
|
||||
|
||||
// Use slightly different algorithms for small and large operands.
|
||||
// See if the result will safely fit into a long. (Largest 2^63-1)
|
||||
if (partToSquare.mag.length == 1 && scaleFactor <= 62) {
|
||||
// Small number algorithm. Everything fits into a long.
|
||||
int newSign = (signum <0 && (exponent&1) == 1 ? -1 : 1);
|
||||
long result = 1;
|
||||
long baseToPow2 = partToSquare.mag[0] & LONG_MASK;
|
||||
|
||||
int workingExponent = exponent;
|
||||
|
||||
// Perform exponentiation using repeated squaring trick
|
||||
while (workingExponent != 0) {
|
||||
if ((workingExponent & 1) == 1) {
|
||||
result = result * baseToPow2;
|
||||
}
|
||||
|
||||
if ((workingExponent >>>= 1) != 0) {
|
||||
baseToPow2 = baseToPow2 * baseToPow2;
|
||||
}
|
||||
}
|
||||
// See if the result will safely fit into an unsigned long. (Largest 2^64-1)
|
||||
if (scaleFactor <= Long.SIZE) {
|
||||
// Small number algorithm. Everything fits into an unsigned long.
|
||||
final int newSign = negative ? -1 : 1;
|
||||
final long result = unsignedLongPow(base.mag[0] & LONG_MASK, exponent);
|
||||
|
||||
// Multiply back the powers of two (quickly, by shifting left)
|
||||
if (powersOfTwo > 0) {
|
||||
if (bitsToShift + scaleFactor <= 62) { // Fits in long?
|
||||
return valueOf((result << bitsToShift) * newSign);
|
||||
} else {
|
||||
return valueOf(result*newSign).shiftLeft(bitsToShift);
|
||||
}
|
||||
} else {
|
||||
return valueOf(result*newSign);
|
||||
}
|
||||
} else {
|
||||
if ((long)bitLength() * exponent / Integer.SIZE > MAX_MAG_LENGTH) {
|
||||
reportOverflow();
|
||||
}
|
||||
|
||||
// Large number algorithm. This is basically identical to
|
||||
// the algorithm above, but calls multiply() and square()
|
||||
// which may use more efficient algorithms for large numbers.
|
||||
BigInteger answer = ONE;
|
||||
|
||||
int workingExponent = exponent;
|
||||
// Perform exponentiation using repeated squaring trick
|
||||
while (workingExponent != 0) {
|
||||
if ((workingExponent & 1) == 1) {
|
||||
answer = answer.multiply(partToSquare);
|
||||
}
|
||||
|
||||
if ((workingExponent >>>= 1) != 0) {
|
||||
partToSquare = partToSquare.square();
|
||||
}
|
||||
}
|
||||
// Multiply back the (exponentiated) powers of two (quickly,
|
||||
// by shifting left)
|
||||
if (powersOfTwo > 0) {
|
||||
answer = answer.shiftLeft(bitsToShift);
|
||||
}
|
||||
|
||||
if (signum < 0 && (exponent&1) == 1) {
|
||||
return answer.negate();
|
||||
} else {
|
||||
return answer;
|
||||
}
|
||||
return bitsToShift + scaleFactor <= Long.SIZE // Fits in long?
|
||||
? new BigInteger(result << bitsToShift, newSign)
|
||||
: new BigInteger(result, newSign).shiftLeft(bitsToShift);
|
||||
}
|
||||
|
||||
if ((bitLength() - 1L) * exponent >= Integer.MAX_VALUE) {
|
||||
reportOverflow();
|
||||
}
|
||||
|
||||
// Large number algorithm. This is basically identical to
|
||||
// the algorithm above, but calls multiply()
|
||||
// which may use more efficient algorithms for large numbers.
|
||||
BigInteger answer = ONE;
|
||||
|
||||
final int expZeros = Integer.numberOfLeadingZeros(exponent);
|
||||
int workingExp = exponent << expZeros;
|
||||
// Perform exponentiation using repeated squaring trick
|
||||
// The loop relies on this invariant:
|
||||
// base^exponent == answer^(2^expLen) * base^(workingExp >>> (32-expLen))
|
||||
for (int expLen = Integer.SIZE - expZeros; expLen > 0; expLen--) {
|
||||
answer = answer.multiply(answer);
|
||||
if (workingExp < 0) // leading bit is set
|
||||
answer = answer.multiply(base);
|
||||
|
||||
workingExp <<= 1;
|
||||
}
|
||||
|
||||
// Multiply back the (exponentiated) powers of two (quickly,
|
||||
// by shifting left)
|
||||
answer = answer.shiftLeft(bitsToShift);
|
||||
return negative ? answer.negate() : answer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes {@code x^n} using repeated squaring trick.
|
||||
* Assumes {@code x != 0 && x^n < 2^Long.SIZE}.
|
||||
*/
|
||||
static long unsignedLongPow(long x, int n) {
|
||||
if (x == 1L || n == 0)
|
||||
return 1L;
|
||||
|
||||
if (x == 2L)
|
||||
return 1L << n;
|
||||
|
||||
/*
|
||||
* The method assumption means that n <= 40 here.
|
||||
* Thus, the loop body executes at most 5 times.
|
||||
*/
|
||||
long pow = 1L;
|
||||
for (; n != 1; n >>>= 1) {
|
||||
if ((n & 1) != 0)
|
||||
pow *= x;
|
||||
|
||||
x *= x;
|
||||
}
|
||||
return pow * x;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
189
test/micro/org/openjdk/bench/java/math/BigIntegerPow.java
Normal file
189
test/micro/org/openjdk/bench/java/math/BigIntegerPow.java
Normal file
@ -0,0 +1,189 @@
|
||||
/*
|
||||
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*/
|
||||
package org.openjdk.bench.java.math;
|
||||
|
||||
import org.openjdk.jmh.annotations.Benchmark;
|
||||
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||
import org.openjdk.jmh.annotations.Fork;
|
||||
import org.openjdk.jmh.annotations.Measurement;
|
||||
import org.openjdk.jmh.annotations.Mode;
|
||||
import org.openjdk.jmh.annotations.OperationsPerInvocation;
|
||||
import org.openjdk.jmh.annotations.OutputTimeUnit;
|
||||
import org.openjdk.jmh.annotations.Scope;
|
||||
import org.openjdk.jmh.annotations.Setup;
|
||||
import org.openjdk.jmh.annotations.State;
|
||||
import org.openjdk.jmh.annotations.Param;
|
||||
import org.openjdk.jmh.annotations.Warmup;
|
||||
import org.openjdk.jmh.infra.Blackhole;
|
||||
import org.openjdk.jmh.runner.Runner;
|
||||
import org.openjdk.jmh.runner.RunnerException;
|
||||
import org.openjdk.jmh.runner.options.Options;
|
||||
import org.openjdk.jmh.runner.options.OptionsBuilder;
|
||||
import org.openjdk.jmh.profile.GCProfiler;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@BenchmarkMode(Mode.AverageTime)
|
||||
@OutputTimeUnit(TimeUnit.NANOSECONDS)
|
||||
@State(Scope.Thread)
|
||||
@Warmup(iterations = 1, time = 1)
|
||||
@Measurement(iterations = 1, time = 1)
|
||||
@Fork(value = 3)
|
||||
public class BigIntegerPow {
|
||||
|
||||
private static final int TESTSIZE = 1;
|
||||
|
||||
private int xsExp = (1 << 20) - 1;
|
||||
/* Each array entry is atmost 64 bits in size */
|
||||
private BigInteger[] xsArray = new BigInteger[TESTSIZE];
|
||||
|
||||
private int sExp = (1 << 18) - 1;
|
||||
/* Each array entry is atmost 256 bits in size */
|
||||
private BigInteger[] sArray = new BigInteger[TESTSIZE];
|
||||
|
||||
private int mExp = (1 << 16) - 1;
|
||||
/* Each array entry is atmost 1024 bits in size */
|
||||
private BigInteger[] mArray = new BigInteger[TESTSIZE];
|
||||
|
||||
private int lExp = (1 << 14) - 1;
|
||||
/* Each array entry is atmost 4096 bits in size */
|
||||
private BigInteger[] lArray = new BigInteger[TESTSIZE];
|
||||
|
||||
private int xlExp = (1 << 12) - 1;
|
||||
/* Each array entry is atmost 16384 bits in size */
|
||||
private BigInteger[] xlArray = new BigInteger[TESTSIZE];
|
||||
|
||||
private int[] randomExps;
|
||||
|
||||
/*
|
||||
* You can run this test via the command line:
|
||||
* $ make test TEST="micro:java.math.BigIntegerPow" MICRO="OPTIONS=-prof gc"
|
||||
*/
|
||||
|
||||
@Setup
|
||||
public void setup() {
|
||||
Random r = new Random(1123);
|
||||
|
||||
randomExps = new int[TESTSIZE];
|
||||
for (int i = 0; i < TESTSIZE; i++) {
|
||||
xsArray[i] = new BigInteger(64, r);
|
||||
sArray[i] = new BigInteger(256, r);
|
||||
mArray[i] = new BigInteger(1024, r);
|
||||
lArray[i] = new BigInteger(4096, r);
|
||||
xlArray[i] = new BigInteger(16384, r);
|
||||
randomExps[i] = r.nextInt(1 << 12);
|
||||
}
|
||||
}
|
||||
|
||||
/** Test BigInteger.pow() with numbers long at most 64 bits */
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowXS(Blackhole bh) {
|
||||
for (BigInteger xs : xsArray) {
|
||||
bh.consume(xs.pow(xsExp));
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowXSRandomExps(Blackhole bh) {
|
||||
int i = 0;
|
||||
for (BigInteger xs : xsArray) {
|
||||
bh.consume(xs.pow(randomExps[i++]));
|
||||
}
|
||||
}
|
||||
|
||||
/** Test BigInteger.pow() with numbers long at most 256 bits */
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowS(Blackhole bh) {
|
||||
for (BigInteger s : sArray) {
|
||||
bh.consume(s.pow(sExp));
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowSRandomExps(Blackhole bh) {
|
||||
int i = 0;
|
||||
for (BigInteger s : sArray) {
|
||||
bh.consume(s.pow(randomExps[i++]));
|
||||
}
|
||||
}
|
||||
|
||||
/** Test BigInteger.pow() with numbers long at most 1024 bits */
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowM(Blackhole bh) {
|
||||
for (BigInteger m : mArray) {
|
||||
bh.consume(m.pow(mExp));
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowMRandomExps(Blackhole bh) {
|
||||
int i = 0;
|
||||
for (BigInteger m : mArray) {
|
||||
bh.consume(m.pow(randomExps[i++]));
|
||||
}
|
||||
}
|
||||
|
||||
/** Test BigInteger.pow() with numbers long at most 4096 bits */
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowL(Blackhole bh) {
|
||||
for (BigInteger l : lArray) {
|
||||
bh.consume(l.pow(lExp));
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowLRandomExps(Blackhole bh) {
|
||||
int i = 0;
|
||||
for (BigInteger l : lArray) {
|
||||
bh.consume(l.pow(randomExps[i++]));
|
||||
}
|
||||
}
|
||||
|
||||
/** Test BigInteger.pow() with numbers long at most 16384 bits */
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowXL(Blackhole bh) {
|
||||
for (BigInteger xl : xlArray) {
|
||||
bh.consume(xl.pow(xlExp));
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@OperationsPerInvocation(TESTSIZE)
|
||||
public void testPowXLRandomExps(Blackhole bh) {
|
||||
int i = 0;
|
||||
for (BigInteger xl : xlArray) {
|
||||
bh.consume(xl.pow(randomExps[i++]));
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user