8295011: EC point multiplication improvement for secp256r1

Reviewed-by: djelinski, jjiang
This commit is contained in:
Xue-Lei Andrew Fan 2022-11-22 18:19:59 +00:00
parent fb6c992f32
commit 260e4dcbfd
5 changed files with 315 additions and 104 deletions

View File

@ -206,12 +206,9 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
//
// Compute nQ (using elliptic curve arithmetic), and verify that
// nQ is the identity element.
ImmutableIntegerModuloP xElem = ops.getField().getElement(x);
ImmutableIntegerModuloP yElem = ops.getField().getElement(y);
AffinePoint affP = new AffinePoint(xElem, yElem);
byte[] order = spec.getOrder().toByteArray();
ArrayUtil.reverse(order);
Point product = ops.multiply(affP, order);
Point product = ops.multiply(key.getW(), order);
if (!ops.isNeutral(product)) {
throw new InvalidKeyException("Point has incorrect order");
}
@ -275,12 +272,8 @@ public final class ECDHKeyAgreement extends KeyAgreementSpi {
scalar.setProduct(cofactor);
int keySize =
(priv.getParams().getCurve().getField().getFieldSize() + 7) / 8;
ImmutableIntegerModuloP x =
field.getElement(pubKey.getW().getAffineX());
ImmutableIntegerModuloP y =
field.getElement(pubKey.getW().getAffineY());
Point product = ops.multiply(new AffinePoint(x, y),
scalar.asByteArray(keySize));
Point product =
ops.multiply(pubKey.getW(), scalar.asByteArray(keySize));
if (ops.isNeutral(product)) {
throw new InvalidKeyException("Product is zero");
}

View File

@ -243,9 +243,6 @@ public class ECDSAOperations {
ImmutableIntegerModuloP u1 = e.multiply(sInv);
ImmutableIntegerModuloP u2 = ri.multiply(sInv);
AffinePoint pub = new AffinePoint(field.getElement(pp.getAffineX()),
field.getElement(pp.getAffineY()));
byte[] temp1 = new byte[length];
b2a(u1, orderField, temp1);
@ -253,7 +250,7 @@ public class ECDSAOperations {
b2a(u2, orderField, temp2);
MutablePoint p1 = ecOps.multiply(basePoint, temp1);
MutablePoint p2 = ecOps.multiply(pub, temp2);
MutablePoint p2 = ecOps.multiply(pp, temp2);
ecOps.setSum(p1, p2.asAffine());
IntegerModuloP result = p1.asAffine().getX();

View File

@ -191,11 +191,7 @@ public final class ECKeyPairGenerator extends KeyPairGeneratorSpi {
int seedSize = (seedBits + 7) / 8;
byte[] privArr = generatePrivateScalar(random, ops, seedSize);
ECPoint genPoint = ecParams.getGenerator();
ImmutableIntegerModuloP x = field.getElement(genPoint.getAffineX());
ImmutableIntegerModuloP y = field.getElement(genPoint.getAffineY());
AffinePoint affGen = new AffinePoint(x, y);
Point pub = ops.multiply(affGen, privArr);
Point pub = ops.multiply(ecParams.getGenerator(), privArr);
AffinePoint affPub = pub.asAffine();
PrivateKey privateKey = new ECPrivateKeyImpl(privArr, ecParams);

View File

@ -26,6 +26,8 @@
package sun.security.ec;
import sun.security.ec.point.*;
import sun.security.util.CurveDB;
import sun.security.util.KnownOIDs;
import sun.security.util.ArrayUtil;
import sun.security.util.math.*;
import sun.security.util.math.intpoly.*;
@ -46,6 +48,10 @@ import java.util.Optional;
*/
public class ECOperations {
private static final ECOperations secp256r1Ops =
new ECOperations(IntegerPolynomialP256.ONE.getElement(
CurveDB.lookup(KnownOIDs.secp256r1.value()).getCurve().getB()),
P256OrderField.ONE);
/*
* An exception indicating a problem with an intermediate value produced
@ -72,10 +78,9 @@ public class ECOperations {
public static Optional<ECOperations> forParameters(ECParameterSpec params) {
EllipticCurve curve = params.getCurve();
if (!(curve.getField() instanceof ECFieldFp)) {
if (!(curve.getField() instanceof ECFieldFp primeField)) {
return Optional.empty();
}
ECFieldFp primeField = (ECFieldFp) curve.getField();
BigInteger three = BigInteger.valueOf(3);
if (!primeField.getP().subtract(curve.getA()).equals(three)) {
@ -193,35 +198,6 @@ public class ECOperations {
return acc == 0;
}
/*
* 4-bit branchless array lookup for projective points.
*/
private void lookup4(ProjectivePoint.Immutable[] arr, int index,
ProjectivePoint.Mutable result, IntegerModuloP zero) {
for (int i = 0; i < 16; i++) {
int xor = index ^ i;
int bit3 = (xor & 0x8) >>> 3;
int bit2 = (xor & 0x4) >>> 2;
int bit1 = (xor & 0x2) >>> 1;
int bit0 = (xor & 0x1);
int inverse = bit0 | bit1 | bit2 | bit3;
int set = 1 - inverse;
ProjectivePoint.Immutable pi = arr[i];
result.conditionalSet(pi, set);
}
}
private void double4(ProjectivePoint.Mutable p, MutableIntegerModuloP t0,
MutableIntegerModuloP t1, MutableIntegerModuloP t2,
MutableIntegerModuloP t3, MutableIntegerModuloP t4) {
for (int i = 0; i < 4; i++) {
setDouble(p, t0, t1, t2, t3, t4);
}
}
/**
* Multiply an affine point by a scalar and return the result as a mutable
* point.
@ -231,58 +207,11 @@ public class ECOperations {
* @return the product
*/
public MutablePoint multiply(AffinePoint affineP, byte[] s) {
return PointMultiplier.of(this, affineP).pointMultiply(s);
}
// 4-bit windowed multiply with branchless lookup.
// The mixed addition is faster, so it is used to construct the array
// at the beginning of the operation.
IntegerFieldModuloP field = affineP.getX().getField();
ImmutableIntegerModuloP zero = field.get0();
// temporaries
MutableIntegerModuloP t0 = zero.mutable();
MutableIntegerModuloP t1 = zero.mutable();
MutableIntegerModuloP t2 = zero.mutable();
MutableIntegerModuloP t3 = zero.mutable();
MutableIntegerModuloP t4 = zero.mutable();
ProjectivePoint.Mutable result = new ProjectivePoint.Mutable(field);
result.getY().setValue(field.get1().mutable());
ProjectivePoint.Immutable[] pointMultiples =
new ProjectivePoint.Immutable[16];
// 0P is neutral---same as initial result value
pointMultiples[0] = result.fixed();
ProjectivePoint.Mutable ps = new ProjectivePoint.Mutable(field);
ps.setValue(affineP);
// 1P = P
pointMultiples[1] = ps.fixed();
// the rest are calculated using mixed point addition
for (int i = 2; i < 16; i++) {
setSum(ps, affineP, t0, t1, t2, t3, t4);
pointMultiples[i] = ps.fixed();
}
ProjectivePoint.Mutable lookupResult = ps.mutable();
for (int i = s.length - 1; i >= 0; i--) {
double4(result, t0, t1, t2, t3, t4);
int high = (0xFF & s[i]) >>> 4;
lookup4(pointMultiples, high, lookupResult, zero);
setSum(result, lookupResult, t0, t1, t2, t3, t4);
double4(result, t0, t1, t2, t3, t4);
int low = 0xF & s[i];
lookup4(pointMultiples, low, lookupResult, zero);
setSum(result, lookupResult, t0, t1, t2, t3, t4);
}
return result;
public MutablePoint multiply(ECPoint ecPoint, byte[] s) {
return PointMultiplier.of(this, ecPoint).pointMultiply(s);
}
/*
@ -404,7 +333,6 @@ public class ECOperations {
p.getZ().setProduct(t4);
t3.setProduct(t0);
p.getZ().setSum(t3);
}
/*
@ -470,7 +398,6 @@ public class ECOperations {
t3.setProduct(t0);
p.getZ().setSum(t3);
}
// The extra step in the Full Public key validation as described in
@ -486,5 +413,290 @@ public class ECOperations {
ArrayUtil.reverse(scalar);
return isNeutral(this.multiply(ap, scalar));
}
}
sealed interface PointMultiplier {
Map<ECPoint, PointMultiplier> multipliers = Map.of(
Secp256R1GeneratorMultiplier.generator,
Secp256R1GeneratorMultiplier.multiplier);
// Multiply the point by a scalar and return the result as a mutable
// point. The multiplier point is specified by the implementation of
// this interface, which could be a general EC point or EC generator
// point.
//
// Multiply the ECPoint (that is specified in the implementation) by
// a scalar and return the result as a ProjectivePoint.Mutable point.
// The point to be multiplied can be a general EC point or the
// generator of a named EC group. The scalar multiplier is an integer
// in little endian byte array representation.
ProjectivePoint.Mutable pointMultiply(byte[] scalar);
static PointMultiplier of(ECOperations ecOps, AffinePoint affPoint) {
PointMultiplier multiplier = multipliers.get(affPoint.toECPoint());
if (multiplier == null) {
multiplier = new Default(ecOps, affPoint);
}
return multiplier;
}
static PointMultiplier of(ECOperations ecOps, ECPoint ecPoint) {
PointMultiplier multiplier = multipliers.get(ecPoint);
if (multiplier == null) {
AffinePoint affPoint =
AffinePoint.fromECPoint(ecPoint, ecOps.getField());
multiplier = new Default(ecOps, affPoint);
}
return multiplier;
}
private static void lookup(
ProjectivePoint.Immutable[] ips, int index,
ProjectivePoint.Mutable result) {
for (int i = 0; i < 16; i++) {
int xor = index ^ i;
int bit3 = (xor & 0x8) >>> 3;
int bit2 = (xor & 0x4) >>> 2;
int bit1 = (xor & 0x2) >>> 1;
int bit0 = (xor & 0x1);
int inverse = bit0 | bit1 | bit2 | bit3;
int set = 1 - inverse;
ProjectivePoint.Immutable pi = ips[i];
result.conditionalSet(pi, set);
}
}
final class Default implements PointMultiplier {
private final AffinePoint affineP;
private final ECOperations ecOps;
private Default(ECOperations ecOps, AffinePoint affineP) {
this.ecOps = ecOps;
this.affineP = affineP;
}
@Override
public ProjectivePoint.Mutable pointMultiply(byte[] s) {
// 4-bit windowed multiply with branchless lookup.
// The mixed addition is faster, so it is used to construct
// the array at the beginning of the operation.
IntegerFieldModuloP field = affineP.getX().getField();
ImmutableIntegerModuloP zero = field.get0();
// temporaries
MutableIntegerModuloP t0 = zero.mutable();
MutableIntegerModuloP t1 = zero.mutable();
MutableIntegerModuloP t2 = zero.mutable();
MutableIntegerModuloP t3 = zero.mutable();
MutableIntegerModuloP t4 = zero.mutable();
ProjectivePoint.Mutable result =
new ProjectivePoint.Mutable(field);
result.getY().setValue(field.get1().mutable());
ProjectivePoint.Immutable[] pointMultiples =
new ProjectivePoint.Immutable[16];
// 0P is neutral---same as initial result value
pointMultiples[0] = result.fixed();
ProjectivePoint.Mutable ps = new ProjectivePoint.Mutable(field);
ps.setValue(affineP);
// 1P = P
pointMultiples[1] = ps.fixed();
// the rest are calculated using mixed point addition
for (int i = 2; i < 16; i++) {
ecOps.setSum(ps, affineP, t0, t1, t2, t3, t4);
pointMultiples[i] = ps.fixed();
}
ProjectivePoint.Mutable lookupResult = ps.mutable();
for (int i = s.length - 1; i >= 0; i--) {
double4(result, t0, t1, t2, t3, t4);
int high = (0xFF & s[i]) >>> 4;
lookup(pointMultiples, high, lookupResult);
ecOps.setSum(result, lookupResult, t0, t1, t2, t3, t4);
double4(result, t0, t1, t2, t3, t4);
int low = 0xF & s[i];
lookup(pointMultiples, low, lookupResult);
ecOps.setSum(result, lookupResult, t0, t1, t2, t3, t4);
}
return result;
}
private void double4(ProjectivePoint.Mutable p,
MutableIntegerModuloP t0, MutableIntegerModuloP t1,
MutableIntegerModuloP t2, MutableIntegerModuloP t3,
MutableIntegerModuloP t4) {
for (int i = 0; i < 4; i++) {
ecOps.setDouble(p, t0, t1, t2, t3, t4);
}
}
}
final class Secp256R1GeneratorMultiplier implements PointMultiplier {
private static final ECPoint generator =
CurveDB.lookup("secp256r1").getGenerator();
private static final PointMultiplier multiplier =
new Secp256R1GeneratorMultiplier();
private static final ImmutableIntegerModuloP zero =
IntegerPolynomialP256.ONE.get0();
private static final ImmutableIntegerModuloP one =
IntegerPolynomialP256.ONE.get1();
@Override
public ProjectivePoint.Mutable pointMultiply(byte[] s) {
MutableIntegerModuloP t0 = zero.mutable();
MutableIntegerModuloP t1 = zero.mutable();
MutableIntegerModuloP t2 = zero.mutable();
MutableIntegerModuloP t3 = zero.mutable();
MutableIntegerModuloP t4 = zero.mutable();
ProjectivePoint.Mutable d = new ProjectivePoint.Mutable(
zero.mutable(),
one.mutable(),
zero.mutable());
ProjectivePoint.Mutable r = d.mutable();
for (int i = 15; i >= 0; i--) {
secp256r1Ops.setDouble(d, t0, t1, t2, t3, t4);
for (int j = 3; j >= 0; j--) {
int pos = i + j * 16;
int index = (bit(s, pos + 192) << 3) |
(bit(s, pos + 128) << 2) |
(bit(s, pos + 64) << 1) |
bit(s, pos);
lookup(P256.points[j], index, r);
secp256r1Ops.setSum(d, r, t0, t1, t2, t3, t4);
}
}
return d;
}
private static int bit(byte[] k, int i) {
return (k[i >> 3] >> (i & 0x07)) & 0x01;
}
// Lazy loading of the tables.
private static final class P256 {
// Pre-computed table to speed up the point multiplication.
//
// This is a 4x16 array of ProjectivePoint.Immutable elements.
// The first row contains the following multiples of the
// generator.
//
// index | point
// --------+----------------
// 0x0000 | 0G
// 0x0001 | 1G
// 0x0002 | (2^64)G
// 0x0003 | (2^64 + 1)G
// 0x0004 | 2^128G
// 0x0005 | (2^128 + 1)G
// 0x0006 | (2^128 + 2^64)G
// 0x0007 | (2^128 + 2^64 + 1)G
// 0x0008 | 2^192G
// 0x0009 | (2^192 + 1)G
// 0x000A | (2^192 + 2^64)G
// 0x000B | (2^192 + 2^64 + 1)G
// 0x000C | (2^192 + 2^128)G
// 0x000D | (2^192 + 2^128 + 1)G
// 0x000E | (2^192 + 2^128 + 2^64)G
// 0x000F | (2^192 + 2^128 + 2^64 + 1)G
//
// For the other 3 rows, points[i][j] = 2^16 * (points[i-1][j].
private static final ProjectivePoint.Immutable[][] points;
// Generate the pre-computed tables. This block may be
// replaced with hard-coded tables in order to speed up
// the class loading.
static {
points = new ProjectivePoint.Immutable[4][16];
BigInteger[] factors = new BigInteger[] {
BigInteger.ONE,
BigInteger.TWO.pow(64),
BigInteger.TWO.pow(128),
BigInteger.TWO.pow(192)
};
BigInteger[] base = new BigInteger[16];
base[0] = BigInteger.ZERO;
base[1] = BigInteger.ONE;
base[2] = factors[1];
for (int i = 3; i < 16; i++) {
base[i] = BigInteger.ZERO;
for (int k = 0; k < 4; k++) {
if (((i >>> k) & 0x01) != 0) {
base[i] = base[i].add(factors[k]);
}
}
}
for (int d = 0; d < 4; d++) {
for (int w = 0; w < 16; w++) {
BigInteger bi = base[w];
if (d != 0) {
bi = bi.multiply(BigInteger.TWO.pow(d * 16));
}
if (w == 0) {
points[d][0] = new ProjectivePoint.Immutable(
zero.fixed(), one.fixed(), zero.fixed());
} else {
PointMultiplier multiplier = new Default(
secp256r1Ops, AffinePoint.fromECPoint(
generator, zero.getField()));
byte[] s = bi.toByteArray();
ArrayUtil.reverse(s);
ProjectivePoint.Mutable m =
multiplier.pointMultiply(s);
points[d][w] = m.setValue(m.asAffine()).fixed();
}
}
}
// Check that the tables are correctly generated.
if (ECOperations.class.desiredAssertionStatus()) {
verifyTables(base);
}
}
private static void verifyTables(BigInteger[] base) {
for (int d = 0; d < 4; d++) {
for (int w = 0; w < 16; w++) {
BigInteger bi = base[w];
if (d != 0) {
bi = bi.multiply(BigInteger.TWO.pow(d * 16));
}
if (w != 0) {
byte[] s = new byte[32];
byte[] b = bi.toByteArray();
ArrayUtil.reverse(b);
System.arraycopy(b, 0, s, 0, b.length);
ProjectivePoint.Mutable m =
multiplier.pointMultiply(s);
ProjectivePoint.Immutable v =
m.setValue(m.asAffine()).fixed();
if (!v.getX().asBigInteger().equals(
points[d][w].getX().asBigInteger()) ||
!v.getY().asBigInteger().equals(
points[d][w].getY().asBigInteger())) {
throw new RuntimeException();
}
}
}
}
}
}
}
}
}

View File

@ -25,7 +25,9 @@
package sun.security.ec.point;
import sun.security.util.math.ImmutableIntegerModuloP;
import sun.security.util.math.IntegerFieldModuloP;
import java.security.spec.ECPoint;
import java.util.Objects;
/**
@ -44,6 +46,17 @@ public class AffinePoint {
this.y = y;
}
public static AffinePoint fromECPoint(
ECPoint ecPoint, IntegerFieldModuloP field) {
return new AffinePoint(
field.getElement(ecPoint.getAffineX()),
field.getElement(ecPoint.getAffineY()));
}
public ECPoint toECPoint() {
return new ECPoint(x.asBigInteger(), y.asBigInteger());
}
public ImmutableIntegerModuloP getX() {
return x;
}