8355719: Reduce memory consumption of BigInteger.pow()

Reviewed-by: rgiulietti
This commit is contained in:
Fabio Romano 2025-05-09 17:06:33 +00:00 committed by Raffaello Giulietti
parent 601f05e06d
commit 1c5eb370b7
2 changed files with 284 additions and 102 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 1996, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1996, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -1246,6 +1246,16 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
return new BigInteger(val);
}
/**
* Constructs a BigInteger with magnitude specified by the long,
* which may not be zero, and the signum specified by the int.
*/
private BigInteger(long mag, int signum) {
assert mag != 0 && signum != 0;
this.signum = signum;
this.mag = toMagArray(mag);
}
/**
* Constructs a BigInteger with the specified value, which may not be zero.
*/
@ -1256,16 +1266,14 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
} else {
signum = 1;
}
mag = toMagArray(val);
}
int highWord = (int)(val >>> 32);
if (highWord == 0) {
mag = new int[1];
mag[0] = (int)val;
} else {
mag = new int[2];
mag[0] = highWord;
mag[1] = (int)val;
}
private static int[] toMagArray(long mag) {
int highWord = (int) (mag >>> 32);
return highWord == 0
? new int[] { (int) mag }
: new int[] { highWord, (int) mag };
}
/**
@ -2589,116 +2597,101 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
if (exponent < 0) {
throw new ArithmeticException("Negative exponent");
}
if (signum == 0) {
return (exponent == 0 ? ONE : this);
}
if (exponent == 0 || this.equals(ONE))
return ONE;
BigInteger partToSquare = this.abs();
if (signum == 0 || exponent == 1)
return this;
BigInteger base = this.abs();
final boolean negative = signum < 0 && (exponent & 1) == 1;
// Factor out powers of two from the base, as the exponentiation of
// these can be done by left shifts only.
// The remaining part can then be exponentiated faster. The
// powers of two will be multiplied back at the end.
int powersOfTwo = partToSquare.getLowestSetBit();
long bitsToShiftLong = (long)powersOfTwo * exponent;
if (bitsToShiftLong > Integer.MAX_VALUE) {
final int powersOfTwo = base.getLowestSetBit();
final long bitsToShiftLong = (long) powersOfTwo * exponent;
final int bitsToShift = (int) bitsToShiftLong;
if (bitsToShift != bitsToShiftLong) {
reportOverflow();
}
int bitsToShift = (int)bitsToShiftLong;
int remainingBits;
// Factor the powers of two out quickly by shifting right, if needed.
if (powersOfTwo > 0) {
partToSquare = partToSquare.shiftRight(powersOfTwo);
remainingBits = partToSquare.bitLength();
if (remainingBits == 1) { // Nothing left but +/- 1?
if (signum < 0 && (exponent&1) == 1) {
return NEGATIVE_ONE.shiftLeft(bitsToShift);
} else {
return ONE.shiftLeft(bitsToShift);
}
}
} else {
remainingBits = partToSquare.bitLength();
if (remainingBits == 1) { // Nothing left but +/- 1?
if (signum < 0 && (exponent&1) == 1) {
return NEGATIVE_ONE;
} else {
return ONE;
}
}
}
// Factor the powers of two out quickly by shifting right.
base = base.shiftRight(powersOfTwo);
final int remainingBits = base.bitLength();
if (remainingBits == 1) // Nothing left but +/- 1?
return (negative ? NEGATIVE_ONE : ONE).shiftLeft(bitsToShift);
// This is a quick way to approximate the size of the result,
// similar to doing log2[n] * exponent. This will give an upper bound
// of how big the result can be, and which algorithm to use.
long scaleFactor = (long)remainingBits * exponent;
final long scaleFactor = (long) remainingBits * exponent;
// Use slightly different algorithms for small and large operands.
// See if the result will safely fit into a long. (Largest 2^63-1)
if (partToSquare.mag.length == 1 && scaleFactor <= 62) {
// Small number algorithm. Everything fits into a long.
int newSign = (signum <0 && (exponent&1) == 1 ? -1 : 1);
long result = 1;
long baseToPow2 = partToSquare.mag[0] & LONG_MASK;
int workingExponent = exponent;
// Perform exponentiation using repeated squaring trick
while (workingExponent != 0) {
if ((workingExponent & 1) == 1) {
result = result * baseToPow2;
}
if ((workingExponent >>>= 1) != 0) {
baseToPow2 = baseToPow2 * baseToPow2;
}
}
// See if the result will safely fit into an unsigned long. (Largest 2^64-1)
if (scaleFactor <= Long.SIZE) {
// Small number algorithm. Everything fits into an unsigned long.
final int newSign = negative ? -1 : 1;
final long result = unsignedLongPow(base.mag[0] & LONG_MASK, exponent);
// Multiply back the powers of two (quickly, by shifting left)
if (powersOfTwo > 0) {
if (bitsToShift + scaleFactor <= 62) { // Fits in long?
return valueOf((result << bitsToShift) * newSign);
} else {
return valueOf(result*newSign).shiftLeft(bitsToShift);
}
} else {
return valueOf(result*newSign);
}
} else {
if ((long)bitLength() * exponent / Integer.SIZE > MAX_MAG_LENGTH) {
reportOverflow();
}
// Large number algorithm. This is basically identical to
// the algorithm above, but calls multiply() and square()
// which may use more efficient algorithms for large numbers.
BigInteger answer = ONE;
int workingExponent = exponent;
// Perform exponentiation using repeated squaring trick
while (workingExponent != 0) {
if ((workingExponent & 1) == 1) {
answer = answer.multiply(partToSquare);
}
if ((workingExponent >>>= 1) != 0) {
partToSquare = partToSquare.square();
}
}
// Multiply back the (exponentiated) powers of two (quickly,
// by shifting left)
if (powersOfTwo > 0) {
answer = answer.shiftLeft(bitsToShift);
}
if (signum < 0 && (exponent&1) == 1) {
return answer.negate();
} else {
return answer;
}
return bitsToShift + scaleFactor <= Long.SIZE // Fits in long?
? new BigInteger(result << bitsToShift, newSign)
: new BigInteger(result, newSign).shiftLeft(bitsToShift);
}
if ((bitLength() - 1L) * exponent >= Integer.MAX_VALUE) {
reportOverflow();
}
// Large number algorithm. This is basically identical to
// the algorithm above, but calls multiply()
// which may use more efficient algorithms for large numbers.
BigInteger answer = ONE;
final int expZeros = Integer.numberOfLeadingZeros(exponent);
int workingExp = exponent << expZeros;
// Perform exponentiation using repeated squaring trick
// The loop relies on this invariant:
// base^exponent == answer^(2^expLen) * base^(workingExp >>> (32-expLen))
for (int expLen = Integer.SIZE - expZeros; expLen > 0; expLen--) {
answer = answer.multiply(answer);
if (workingExp < 0) // leading bit is set
answer = answer.multiply(base);
workingExp <<= 1;
}
// Multiply back the (exponentiated) powers of two (quickly,
// by shifting left)
answer = answer.shiftLeft(bitsToShift);
return negative ? answer.negate() : answer;
}
/**
* Computes {@code x^n} using repeated squaring trick.
* Assumes {@code x != 0 && x^n < 2^Long.SIZE}.
*/
static long unsignedLongPow(long x, int n) {
if (x == 1L || n == 0)
return 1L;
if (x == 2L)
return 1L << n;
/*
* The method assumption means that n <= 40 here.
* Thus, the loop body executes at most 5 times.
*/
long pow = 1L;
for (; n != 1; n >>>= 1) {
if ((n & 1) != 0)
pow *= x;
x *= x;
}
return pow * x;
}
/**

View File

@ -0,0 +1,189 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.math;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.profile.GCProfiler;
import java.math.BigInteger;
import java.util.Random;
import java.util.concurrent.TimeUnit;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Warmup(iterations = 1, time = 1)
@Measurement(iterations = 1, time = 1)
@Fork(value = 3)
public class BigIntegerPow {
private static final int TESTSIZE = 1;
private int xsExp = (1 << 20) - 1;
/* Each array entry is atmost 64 bits in size */
private BigInteger[] xsArray = new BigInteger[TESTSIZE];
private int sExp = (1 << 18) - 1;
/* Each array entry is atmost 256 bits in size */
private BigInteger[] sArray = new BigInteger[TESTSIZE];
private int mExp = (1 << 16) - 1;
/* Each array entry is atmost 1024 bits in size */
private BigInteger[] mArray = new BigInteger[TESTSIZE];
private int lExp = (1 << 14) - 1;
/* Each array entry is atmost 4096 bits in size */
private BigInteger[] lArray = new BigInteger[TESTSIZE];
private int xlExp = (1 << 12) - 1;
/* Each array entry is atmost 16384 bits in size */
private BigInteger[] xlArray = new BigInteger[TESTSIZE];
private int[] randomExps;
/*
* You can run this test via the command line:
* $ make test TEST="micro:java.math.BigIntegerPow" MICRO="OPTIONS=-prof gc"
*/
@Setup
public void setup() {
Random r = new Random(1123);
randomExps = new int[TESTSIZE];
for (int i = 0; i < TESTSIZE; i++) {
xsArray[i] = new BigInteger(64, r);
sArray[i] = new BigInteger(256, r);
mArray[i] = new BigInteger(1024, r);
lArray[i] = new BigInteger(4096, r);
xlArray[i] = new BigInteger(16384, r);
randomExps[i] = r.nextInt(1 << 12);
}
}
/** Test BigInteger.pow() with numbers long at most 64 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowXS(Blackhole bh) {
for (BigInteger xs : xsArray) {
bh.consume(xs.pow(xsExp));
}
}
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowXSRandomExps(Blackhole bh) {
int i = 0;
for (BigInteger xs : xsArray) {
bh.consume(xs.pow(randomExps[i++]));
}
}
/** Test BigInteger.pow() with numbers long at most 256 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowS(Blackhole bh) {
for (BigInteger s : sArray) {
bh.consume(s.pow(sExp));
}
}
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowSRandomExps(Blackhole bh) {
int i = 0;
for (BigInteger s : sArray) {
bh.consume(s.pow(randomExps[i++]));
}
}
/** Test BigInteger.pow() with numbers long at most 1024 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowM(Blackhole bh) {
for (BigInteger m : mArray) {
bh.consume(m.pow(mExp));
}
}
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowMRandomExps(Blackhole bh) {
int i = 0;
for (BigInteger m : mArray) {
bh.consume(m.pow(randomExps[i++]));
}
}
/** Test BigInteger.pow() with numbers long at most 4096 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowL(Blackhole bh) {
for (BigInteger l : lArray) {
bh.consume(l.pow(lExp));
}
}
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowLRandomExps(Blackhole bh) {
int i = 0;
for (BigInteger l : lArray) {
bh.consume(l.pow(randomExps[i++]));
}
}
/** Test BigInteger.pow() with numbers long at most 16384 bits */
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowXL(Blackhole bh) {
for (BigInteger xl : xlArray) {
bh.consume(xl.pow(xlExp));
}
}
@Benchmark
@OperationsPerInvocation(TESTSIZE)
public void testPowXLRandomExps(Blackhole bh) {
int i = 0;
for (BigInteger xl : xlArray) {
bh.consume(xl.pow(randomExps[i++]));
}
}
}