8341402: BigDecimal's square root optimization

Reviewed-by: rgiulietti
This commit is contained in:
fabioromano1 2025-04-10 14:34:52 +00:00 committed by Raffaello Giulietti
parent 7e69b98e05
commit c4c3edfa96
2 changed files with 268 additions and 229 deletions

View File

@ -2145,247 +2145,177 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
* @since 9
*/
public BigDecimal sqrt(MathContext mc) {
int signum = signum();
if (signum == 1) {
/*
* The following code draws on the algorithm presented in
* "Properly Rounded Variable Precision Square Root," Hull and
* Abrham, ACM Transactions on Mathematical Software, Vol 11,
* No. 3, September 1985, Pages 229-237.
*
* The BigDecimal computational model differs from the one
* presented in the paper in several ways: first BigDecimal
* numbers aren't necessarily normalized, second many more
* rounding modes are supported, including UNNECESSARY, and
* exact results can be requested.
*
* The main steps of the algorithm below are as follows,
* first argument reduce the value to the numerical range
* [1, 10) using the following relations:
*
* x = y * 10 ^ exp
* sqrt(x) = sqrt(y) * 10^(exp / 2) if exp is even
* sqrt(x) = sqrt(y/10) * 10 ^((exp+1)/2) is exp is odd
*
* Then use Newton's iteration on the reduced value to compute
* the numerical digits of the desired result.
*
* Finally, scale back to the desired exponent range and
* perform any adjustment to get the preferred scale in the
* representation.
*/
// The code below favors relative simplicity over checking
// for special cases that could run faster.
int preferredScale = this.scale()/2;
BigDecimal zeroWithFinalPreferredScale = valueOf(0L, preferredScale);
// First phase of numerical normalization, strip trailing
// zeros and check for even powers of 10.
BigDecimal stripped = this.stripTrailingZeros();
int strippedScale = stripped.scale();
// Numerically sqrt(10^2N) = 10^N
if (stripped.isPowerOfTen() &&
strippedScale % 2 == 0) {
BigDecimal result = valueOf(1L, strippedScale/2);
if (result.scale() != preferredScale) {
// Adjust to requested precision and preferred
// scale as appropriate.
result = result.add(zeroWithFinalPreferredScale, mc);
}
final int signum = signum();
if (signum != 1) {
switch (signum) {
case -1 -> throw new ArithmeticException("Attempted square root of negative BigDecimal");
case 0 -> {
BigDecimal result = valueOf(0L, scale/2);
assert squareRootResultAssertions(result, mc);
return result;
}
default -> throw new AssertionError("Bad value from signum");
}
}
/*
* The main steps of the algorithm below are as follows,
* first argument reduce the value to an integer
* using the following relations:
*
* x = y * 10 ^ exp
* sqrt(x) = sqrt(y) * 10^(exp / 2) if exp is even
* sqrt(x) = sqrt(y*10) * 10^((exp-1)/2) is exp is odd
*
* Then use BigInteger.sqrt() on the reduced value to compute
* the numerical digits of the desired result.
*
* Finally, scale back to the desired exponent range and
* perform any adjustment to get the preferred scale in the
* representation.
*/
// The code below favors relative simplicity over checking
// for special cases that could run faster.
final int preferredScale = this.scale/2;
BigDecimal result;
if (mc.roundingMode == RoundingMode.UNNECESSARY || mc.precision == 0) { // Exact result requested
// To avoid trailing zeros in the result, strip trailing zeros.
final BigDecimal stripped = this.stripTrailingZeros();
final int strippedScale = stripped.scale;
if ((strippedScale & 1) != 0) // 10*stripped.unscaledValue() can't be an exact square
throw new ArithmeticException("Computed square root not exact.");
// Check for even powers of 10. Numerically sqrt(10^2N) = 10^N
if (stripped.isPowerOfTen()) {
result = valueOf(1L, strippedScale >> 1);
// Adjust to requested precision and preferred
// scale as appropriate.
return result.adjustToPreferredScale(preferredScale, mc.precision);
}
// After stripTrailingZeros, the representation is normalized as
//
// unscaledValue * 10^(-scale)
//
// where unscaledValue is an integer with the minimum
// precision for the cohort of the numerical value. To
// allow binary floating-point hardware to be used to get
// approximately a 15 digit approximation to the square
// root, it is helpful to instead normalize this so that
// the significand portion is to right of the decimal
// point by roughly (scale() - precision() + 1).
// precision for the cohort of the numerical value and the scale is even.
BigInteger[] sqrtRem = stripped.unscaledValue().sqrtAndRemainder();
result = new BigDecimal(sqrtRem[0], strippedScale >> 1);
// Now the precision / scale adjustment
int scaleAdjust = 0;
int scale = stripped.scale() - stripped.precision() + 1;
if (scale % 2 == 0) {
scaleAdjust = scale;
} else {
scaleAdjust = scale - 1;
}
BigDecimal working = stripped.scaleByPowerOfTen(scaleAdjust);
assert // Verify 0.1 <= working < 10
ONE_TENTH.compareTo(working) <= 0 && working.compareTo(TEN) < 0;
// Use good ole' Math.sqrt to get the initial guess for
// the Newton iteration, good to at least 15 decimal
// digits. This approach does incur the cost of a
//
// BigDecimal -> double -> BigDecimal
//
// conversion cycle, but it avoids the need for several
// Newton iterations in BigDecimal arithmetic to get the
// working answer to 15 digits of precision. If many fewer
// than 15 digits were needed, it might be faster to do
// the loop entirely in BigDecimal arithmetic.
//
// (A double value might have as many as 17 decimal
// digits of precision; it depends on the relative density
// of binary and decimal numbers at different regions of
// the number line.)
//
// (It would be possible to check for certain special
// cases to avoid doing any Newton iterations. For
// example, if the BigDecimal -> double conversion was
// known to be exact and the rounding mode had a
// low-enough precision, the post-Newton rounding logic
// could be applied directly.)
BigDecimal guess = new BigDecimal(Math.sqrt(working.doubleValue()));
int guessPrecision = 15;
int originalPrecision = mc.getPrecision();
int targetPrecision;
// If an exact value is requested, it must only need about
// half of the input digits to represent since multiplying
// an N digit number by itself yield a 2N-1 digit or 2N
// digit result.
if (originalPrecision == 0) {
targetPrecision = stripped.precision()/2 + 1;
} else {
/*
* To avoid the need for post-Newton fix-up logic, in
* the case of half-way rounding modes, double the
* target precision so that the "2p + 2" property can
* be relied on to accomplish the final rounding.
*/
switch (mc.getRoundingMode()) {
case HALF_UP:
case HALF_DOWN:
case HALF_EVEN:
targetPrecision = 2 * originalPrecision;
if (targetPrecision < 0) // Overflow
targetPrecision = Integer.MAX_VALUE - 2;
break;
default:
targetPrecision = originalPrecision;
break;
}
}
// When setting the precision to use inside the Newton
// iteration loop, take care to avoid the case where the
// precision of the input exceeds the requested precision
// and rounding the input value too soon.
BigDecimal approx = guess;
int workingPrecision = working.precision();
do {
int tmpPrecision = Math.max(Math.max(guessPrecision, targetPrecision + 2),
workingPrecision);
MathContext mcTmp = new MathContext(tmpPrecision, RoundingMode.HALF_EVEN);
// approx = 0.5 * (approx + fraction / approx)
approx = ONE_HALF.multiply(approx.add(working.divide(approx, mcTmp), mcTmp));
guessPrecision *= 2;
} while (guessPrecision < targetPrecision + 2);
BigDecimal result;
RoundingMode targetRm = mc.getRoundingMode();
if (targetRm == RoundingMode.UNNECESSARY || originalPrecision == 0) {
RoundingMode tmpRm =
(targetRm == RoundingMode.UNNECESSARY) ? RoundingMode.DOWN : targetRm;
MathContext mcTmp = new MathContext(targetPrecision, tmpRm);
result = approx.scaleByPowerOfTen(-scaleAdjust/2).round(mcTmp);
// If result*result != this numerically, the square
// root isn't exact
if (this.subtract(result.square()).compareTo(ZERO) != 0) {
throw new ArithmeticException("Computed square root not exact.");
}
} else {
result = approx.scaleByPowerOfTen(-scaleAdjust/2).round(mc);
switch (targetRm) {
case DOWN:
case FLOOR:
// Check if too big
if (result.square().compareTo(this) > 0) {
BigDecimal ulp = result.ulp();
// Adjust increment down in case of 1.0 = 10^0
// since the next smaller number is only 1/10
// as far way as the next larger at exponent
// boundaries. Test approx and *not* result to
// avoid having to detect an arbitrary power
// of ten.
if (approx.compareTo(ONE) == 0) {
ulp = ulp.multiply(ONE_TENTH);
}
result = result.subtract(ulp);
}
break;
case UP:
case CEILING:
// Check if too small
if (result.square().compareTo(this) < 0) {
result = result.add(result.ulp());
}
break;
default:
// No additional work, rely on "2p + 2" property
// for correct rounding. Alternatively, could
// instead run the Newton iteration to around p
// digits and then do tests and fix-ups on the
// rounded value. One possible set of tests and
// fix-ups is given in the Hull and Abrham paper;
// however, additional half-way cases can occur
// for BigDecimal given the more varied
// combinations of input and output precisions
// supported.
break;
}
}
// If result*result != this numerically or requires too high precision,
// the square root isn't exact
if (sqrtRem[1].signum != 0 || mc.precision != 0 && result.precision() > mc.precision)
throw new ArithmeticException("Computed square root not exact.");
// Test numerical properties at full precision before any
// scale adjustments.
assert squareRootResultAssertions(result, mc);
if (result.scale() != preferredScale) {
// The preferred scale of an add is
// max(addend.scale(), augend.scale()). Therefore, if
// the scale of the result is first minimized using
// stripTrailingZeros(), adding a zero of the
// preferred scale rounding to the correct precision
// will perform the proper scale vs precision
// tradeoffs.
result = result.stripTrailingZeros().
add(zeroWithFinalPreferredScale,
new MathContext(originalPrecision, RoundingMode.UNNECESSARY));
}
return result;
} else {
BigDecimal result = null;
switch (signum) {
case -1:
throw new ArithmeticException("Attempted square root " +
"of negative BigDecimal");
case 0:
result = valueOf(0L, scale()/2);
assert squareRootResultAssertions(result, mc);
return result;
// Adjust to requested precision and preferred
// scale as appropriate.
return result.adjustToPreferredScale(preferredScale, mc.precision);
}
// To allow BigInteger.sqrt() to be used to get the square
// root, it is necessary to normalize the input so that
// its integer part is sufficient to get the square root
// with the desired precision.
default:
throw new AssertionError("Bad value from signum");
final boolean halfWay = isHalfWay(mc.roundingMode);
// To obtain a square root with N digits,
// the radicand must have at least 2*(N-1)+1 == 2*N-1 digits.
final long minWorkingPrec = ((mc.precision + (halfWay ? 1L : 0L)) << 1) - 1L;
// normScale is the number of digits to take from the fraction of the input
long normScale = minWorkingPrec - this.precision() + this.scale;
normScale += normScale & 1L; // the scale for normalizing must be even
final long workingScale = this.scale - normScale;
if (workingScale != (int) workingScale)
throw new ArithmeticException("Overflow");
BigDecimal working = new BigDecimal(this.intVal, this.intCompact, (int) workingScale, this.precision);
BigInteger workingInt = working.toBigInteger();
BigInteger sqrt;
long resultScale = normScale >> 1;
// Round sqrt with the specified settings
if (halfWay) { // half-way rounding
BigInteger workingSqrt = workingInt.sqrt();
// remove the one-tenth digit
BigInteger[] quotRem10 = workingSqrt.divideAndRemainder(BigInteger.TEN);
sqrt = quotRem10[0];
resultScale--;
boolean increment = false;
int digit = quotRem10[1].intValue();
if (digit > 5) {
increment = true;
} else if (digit == 5) {
if (mc.roundingMode == RoundingMode.HALF_UP
|| mc.roundingMode == RoundingMode.HALF_EVEN && sqrt.testBit(0)
// Check if remainder is non-zero
|| !workingInt.equals(workingSqrt.multiply(workingSqrt))
|| !working.isInteger()) {
increment = true;
}
}
if (increment)
sqrt = sqrt.add(1L);
} else {
switch (mc.roundingMode) {
case DOWN, FLOOR -> sqrt = workingInt.sqrt(); // No need to round
case UP, CEILING -> {
BigInteger[] sqrtRem = workingInt.sqrtAndRemainder();
sqrt = sqrtRem[0];
// Check if remainder is non-zero
if (sqrtRem[1].signum != 0 || !working.isInteger())
sqrt = sqrt.add(1L);
}
default -> throw new AssertionError("Unexpected value for RoundingMode: " + mc.roundingMode);
}
}
result = new BigDecimal(sqrt, checkScale(sqrt, resultScale), mc); // mc ensures no increase of precision
// Test numerical properties at full precision before any
// scale adjustments.
assert squareRootResultAssertions(result, mc);
// Adjust to requested precision and preferred
// scale as appropriate.
if (result.scale > preferredScale) // else can't increase the result's precision to fit the preferred scale
result = stripZerosToMatchScale(result.intVal, result.intCompact, result.scale, preferredScale);
return result;
}
/**
* Assumes {@code (precision() <= maxPrecision || maxPrecision == 0) && this != 0}.
* @param preferredScale the scale to reach
* @param maxPrecision the largest precision the result can have.
* {@code maxPrecision == 0} means that the result can have arbitrary precision.
* @return a BigDecimal numerically equivalent to {@code this}, whose precision
* does not exceed {@code maxPrecision} and whose scale is the closest
* to {@code preferredScale}.
*/
private BigDecimal adjustToPreferredScale(int preferredScale, int maxPrecision) {
BigDecimal result = this;
if (result.scale > preferredScale) {
result = stripZerosToMatchScale(result.intVal, result.intCompact, result.scale, preferredScale);
} else if (result.scale < preferredScale) {
int maxScale = maxPrecision == 0 ?
preferredScale : (int) Math.min(preferredScale, result.scale + (long) (maxPrecision - result.precision()));
result = result.setScale(maxScale);
}
return result;
}
private static boolean isHalfWay(RoundingMode m) {
return switch (m) {
case HALF_DOWN, HALF_UP, HALF_EVEN -> true;
case FLOOR, CEILING, DOWN, UP, UNNECESSARY -> false;
};
}
private BigDecimal square() {
@ -3553,6 +3483,19 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
return buf.toString();
}
/**
* @return {@code true} if and only if {@code this == this.toBigInteger()}
*/
boolean isInteger() {
if (scale <= 0 || signum() == 0)
return true;
// Get an upper bound of precision() without using big powers of 10 (see bigDigitLength())
int digitLen = precision != 0 ? precision
: (intCompact != INFLATED ? precision() : (digitLengthLower(unscaledValue()) + 1));
return digitLen > scale && stripZerosToMatchScale(intVal, intCompact, scale, 0L).scale == 0;
}
/**
* Converts this {@code BigDecimal} to a {@code BigInteger}.
* This conversion is analogous to the
@ -4602,8 +4545,15 @@ public class BigDecimal extends Number implements Comparable<BigDecimal> {
*/
if (b.signum == 0)
return 1;
int r = (int)((((long)b.bitLength() + 1) * 646456993) >>> 31);
return b.compareMagnitude(bigTenToThe(r)) < 0? r : r+1;
int r = digitLengthLower(b);
return b.compareMagnitude(bigTenToThe(r)) < 0 ? r : r + 1;
}
/**
* @return an integer {@code r} such that {@code 10^(r-1) <= abs(b) < 10^(r+1)}.
*/
private static int digitLengthLower(BigInteger b) {
return (int) (((b.abs().bitLength() + 1L) * 646456993L) >>> 31);
}
/**

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2016, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -23,7 +23,7 @@
/*
* @test
* @bug 4851777 8233452
* @bug 4851777 8233452 8341402
* @summary Tests of BigDecimal.sqrt().
*/
@ -60,6 +60,8 @@ public class SquareRootTests {
failures += nearTen();
failures += nearOne();
failures += halfWay();
failures += exactResultTests();
failures += scaleOverflowTest();
if (failures > 0 ) {
throw new RuntimeException("Incurred " + failures + " failures" +
@ -171,6 +173,93 @@ public class SquareRootTests {
bd.sqrt(mc), "sqrt(" + bd + ") under " + mc);
}
private static int exactResultTests() {
int failures = 0;
MathContext unnecessary = new MathContext(1, RoundingMode.UNNECESSARY);
MathContext arbitrary = new MathContext(0, RoundingMode.CEILING);
BigDecimal[] errCases = {
// (strippedScale & 1) != 0
BigDecimal.TEN,
// (strippedScale & 1) == 0 && !stripped.isPowerOfTen() && sqrtRem[1].signum != 0
BigDecimal.TWO,
};
for (BigDecimal input : errCases) {
BigDecimal result;
// mc.roundingMode == RoundingMode.UNNECESSARY
try {
result = input.sqrt(unnecessary);
System.err.println("Unexpected sqrt with UNNECESSARY RoundingMode: (" + input + ").sqrt() = " + result);
failures += 1;
} catch (ArithmeticException e) {
// Expected
}
// mc.roundingMode != RoundingMode.UNNECESSARY && mc.precision == 0
try {
result = input.sqrt(arbitrary);
System.err.println("Unexpected sqrt with mc.precision == 0: (" + input + ").sqrt() = " + result);
failures += 1;
} catch (ArithmeticException e) {
// Expected
}
}
// (strippedScale & 1) == 0
// !stripped.isPowerOfTen() && sqrtRem[1].signum == 0 && (mc.precision != 0 && result.precision() > mc.precision)
try {
BigDecimal input = BigDecimal.valueOf(121);
BigDecimal result = input.sqrt(unnecessary);
System.err.println("Unexpected sqrt with result.precision() > mc.precision: ("
+ input + ").sqrt() = " + result);
failures += 1;
} catch (ArithmeticException e) {
// Expected
}
BigDecimal four = BigDecimal.valueOf(4);
Object[][] cases = {
// stripped.isPowerOfTen() && mc.roundingMode == RoundingMode.UNNECESSARY
{ BigDecimal.ONE, unnecessary, BigDecimal.ONE },
// stripped.isPowerOfTen() && mc.roundingMode != RoundingMode.UNNECESSARY && mc.precision == 0
{ BigDecimal.ONE, arbitrary, BigDecimal.ONE },
// !stripped.isPowerOfTen() && mc.roundingMode == RoundingMode.UNNECESSARY
// && sqrtRem[1].signum == 0 && mc.precision == 0
{ four, new MathContext(0, RoundingMode.UNNECESSARY), BigDecimal.TWO },
// !stripped.isPowerOfTen() && mc.roundingMode != RoundingMode.UNNECESSARY
// && sqrtRem[1].signum == 0 && mc.precision == 0
{ four, arbitrary, BigDecimal.TWO },
// !stripped.isPowerOfTen() && sqrtRem[1].signum == 0
// && (mc.precision != 0 && result.precision() <= mc.precision)
{ four, unnecessary, BigDecimal.TWO },
};
for (Object[] testCase : cases) {
BigDecimal expected = (BigDecimal) testCase[2];
BigDecimal result = ((BigDecimal) testCase[0]).sqrt((MathContext) testCase[1]);
failures += compare(expected, result, true, "Exact results");
}
return failures;
}
private static int scaleOverflowTest() {
int failures = 0;
try {
BigDecimal.valueOf(1, -1).sqrt(new MathContext((1 << 30) + 1, RoundingMode.UP));
System.err.println("ArithmeticException expected: possible overflow undetected "
+ "or the range of supported values for the algorithm has extended.");
failures += 1;
} catch (ArithmeticException e) {
// Expected
}
return failures;
}
/**
* sqrt(10^2N) is 10^N
* Both numerical value and representation should be verified