From c5e5316f2e28966e717239720557dab8ef72111d Mon Sep 17 00:00:00 2001 From: catbref Date: Sun, 8 May 2022 09:29:13 +0100 Subject: [PATCH] Schnorr public key and signature aggregation for 'online accounts'. Aggregated signature should reduce block payload significantly, as well as associated network, memory & CPU loads. org.qortal.crypto.BouncyCastle25519 renamed to Qortal25519Extras. Our class provides additional features such as DH-based shared secret, aggregating public keys & signatures and sign/verify for aggregate use. BouncyCastle's Ed25519 class copied in as BouncyCastleEd25519, but with 'private' modifiers changed to 'protected', to allow extension by our Qortal25519Extras class, and to avoid lots of messy reflection-based calls. --- .../org/qortal/crypto/BouncyCastle25519.java | 99 -- .../qortal/crypto/BouncyCastleEd25519.java | 1427 +++++++++++++++++ src/main/java/org/qortal/crypto/Crypto.java | 4 +- .../org/qortal/crypto/Qortal25519Extras.java | 234 +++ .../java/org/qortal/test/CryptoTests.java | 14 +- .../java/org/qortal/test/SchnorrTests.java | 190 +++ 6 files changed, 1860 insertions(+), 108 deletions(-) delete mode 100644 src/main/java/org/qortal/crypto/BouncyCastle25519.java create mode 100644 src/main/java/org/qortal/crypto/BouncyCastleEd25519.java create mode 100644 src/main/java/org/qortal/crypto/Qortal25519Extras.java create mode 100644 src/test/java/org/qortal/test/SchnorrTests.java diff --git a/src/main/java/org/qortal/crypto/BouncyCastle25519.java b/src/main/java/org/qortal/crypto/BouncyCastle25519.java deleted file mode 100644 index 1a2e0de9..00000000 --- a/src/main/java/org/qortal/crypto/BouncyCastle25519.java +++ /dev/null @@ -1,99 +0,0 @@ -package org.qortal.crypto; - -import java.lang.reflect.Constructor; -import java.lang.reflect.Field; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.util.Arrays; - -import org.bouncycastle.crypto.Digest; -import org.bouncycastle.math.ec.rfc7748.X25519; -import org.bouncycastle.math.ec.rfc7748.X25519Field; -import org.bouncycastle.math.ec.rfc8032.Ed25519; - -/** Additions to BouncyCastle providing Ed25519 to X25519 key conversion. */ -public class BouncyCastle25519 { - - private static final Class pointAffineClass; - private static final Constructor pointAffineCtor; - private static final Method decodePointVarMethod; - private static final Field yField; - - static { - try { - Class ed25519Class = Ed25519.class; - pointAffineClass = Arrays.stream(ed25519Class.getDeclaredClasses()).filter(clazz -> clazz.getSimpleName().equals("PointAffine")).findFirst().get(); - if (pointAffineClass == null) - throw new ClassNotFoundException("Can't locate PointExt inner class inside Ed25519"); - - decodePointVarMethod = ed25519Class.getDeclaredMethod("decodePointVar", byte[].class, int.class, boolean.class, pointAffineClass); - decodePointVarMethod.setAccessible(true); - - pointAffineCtor = pointAffineClass.getDeclaredConstructors()[0]; - pointAffineCtor.setAccessible(true); - - yField = pointAffineClass.getDeclaredField("y"); - yField.setAccessible(true); - } catch (NoSuchMethodException | SecurityException | IllegalArgumentException | NoSuchFieldException | ClassNotFoundException e) { - throw new RuntimeException("Can't initialize BouncyCastle25519 shim", e); - } - } - - private static int[] obtainYFromPublicKey(byte[] ed25519PublicKey) { - try { - Object pA = pointAffineCtor.newInstance(); - - Boolean result = (Boolean) decodePointVarMethod.invoke(null, ed25519PublicKey, 0, true, pA); - if (result == null || !result) - return null; - - return (int[]) yField.get(pA); - } catch (SecurityException | InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException e) { - throw new RuntimeException("Can't reflect into BouncyCastle", e); - } - } - - public static byte[] toX25519PublicKey(byte[] ed25519PublicKey) { - int[] one = new int[X25519Field.SIZE]; - X25519Field.one(one); - - int[] y = obtainYFromPublicKey(ed25519PublicKey); - - int[] oneMinusY = new int[X25519Field.SIZE]; - X25519Field.sub(one, y, oneMinusY); - - int[] onePlusY = new int[X25519Field.SIZE]; - X25519Field.add(one, y, onePlusY); - - int[] oneMinusYInverted = new int[X25519Field.SIZE]; - X25519Field.inv(oneMinusY, oneMinusYInverted); - - int[] u = new int[X25519Field.SIZE]; - X25519Field.mul(onePlusY, oneMinusYInverted, u); - - X25519Field.normalize(u); - - byte[] x25519PublicKey = new byte[X25519.SCALAR_SIZE]; - X25519Field.encode(u, x25519PublicKey, 0); - - return x25519PublicKey; - } - - public static byte[] toX25519PrivateKey(byte[] ed25519PrivateKey) { - Digest d = Ed25519.createPrehash(); - byte[] h = new byte[d.getDigestSize()]; - - d.update(ed25519PrivateKey, 0, ed25519PrivateKey.length); - d.doFinal(h, 0); - - byte[] s = new byte[X25519.SCALAR_SIZE]; - - System.arraycopy(h, 0, s, 0, X25519.SCALAR_SIZE); - s[0] &= 0xF8; - s[X25519.SCALAR_SIZE - 1] &= 0x7F; - s[X25519.SCALAR_SIZE - 1] |= 0x40; - - return s; - } - -} diff --git a/src/main/java/org/qortal/crypto/BouncyCastleEd25519.java b/src/main/java/org/qortal/crypto/BouncyCastleEd25519.java new file mode 100644 index 00000000..ebcf0f97 --- /dev/null +++ b/src/main/java/org/qortal/crypto/BouncyCastleEd25519.java @@ -0,0 +1,1427 @@ +package org.qortal.crypto; + +import java.security.SecureRandom; + +import org.bouncycastle.crypto.Digest; +import org.bouncycastle.crypto.digests.SHA512Digest; +import org.bouncycastle.math.ec.rfc7748.X25519; +import org.bouncycastle.math.ec.rfc7748.X25519Field; +import org.bouncycastle.math.raw.Interleave; +import org.bouncycastle.math.raw.Nat; +import org.bouncycastle.math.raw.Nat256; +import org.bouncycastle.util.Arrays; + +/** + * Duplicate of {@link org.bouncycastle.math.ec.rfc8032.Ed25519}, + * but with {@code private} modifiers replaced with {@code protected}, + * to allow for extension by {@link org.qortal.crypto.Qortal25519Extras}. + */ +public abstract class BouncyCastleEd25519 +{ + // -x^2 + y^2 == 1 + 0x52036CEE2B6FFE738CC740797779E89800700A4D4141D8AB75EB4DCA135978A3 * x^2 * y^2 + + public static final class Algorithm + { + public static final int Ed25519 = 0; + public static final int Ed25519ctx = 1; + public static final int Ed25519ph = 2; + } + + protected static class F extends X25519Field {}; + + protected static final long M08L = 0x000000FFL; + protected static final long M28L = 0x0FFFFFFFL; + protected static final long M32L = 0xFFFFFFFFL; + + protected static final int POINT_BYTES = 32; + protected static final int SCALAR_INTS = 8; + protected static final int SCALAR_BYTES = SCALAR_INTS * 4; + + public static final int PREHASH_SIZE = 64; + public static final int PUBLIC_KEY_SIZE = POINT_BYTES; + public static final int SECRET_KEY_SIZE = 32; + public static final int SIGNATURE_SIZE = POINT_BYTES + SCALAR_BYTES; + + // "SigEd25519 no Ed25519 collisions" + protected static final byte[] DOM2_PREFIX = new byte[]{ 0x53, 0x69, 0x67, 0x45, 0x64, 0x32, 0x35, 0x35, 0x31, 0x39, + 0x20, 0x6e, 0x6f, 0x20, 0x45, 0x64, 0x32, 0x35, 0x35, 0x31, 0x39, 0x20, 0x63, 0x6f, 0x6c, 0x6c, 0x69, 0x73, + 0x69, 0x6f, 0x6e, 0x73 }; + + protected static final int[] P = new int[]{ 0xFFFFFFED, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, 0x7FFFFFFF }; + protected static final int[] L = new int[]{ 0x5CF5D3ED, 0x5812631A, 0xA2F79CD6, 0x14DEF9DE, 0x00000000, 0x00000000, + 0x00000000, 0x10000000 }; + + protected static final int L0 = 0xFCF5D3ED; // L0:26/-- + protected static final int L1 = 0x012631A6; // L1:24/22 + protected static final int L2 = 0x079CD658; // L2:27/-- + protected static final int L3 = 0xFF9DEA2F; // L3:23/-- + protected static final int L4 = 0x000014DF; // L4:12/11 + + protected static final int[] B_x = new int[]{ 0x0325D51A, 0x018B5823, 0x007B2C95, 0x0304A92D, 0x00D2598E, 0x01D6DC5C, + 0x01388C7F, 0x013FEC0A, 0x029E6B72, 0x0042D26D }; + protected static final int[] B_y = new int[]{ 0x02666658, 0x01999999, 0x00666666, 0x03333333, 0x00CCCCCC, 0x02666666, + 0x01999999, 0x00666666, 0x03333333, 0x00CCCCCC, }; + protected static final int[] C_d = new int[]{ 0x035978A3, 0x02D37284, 0x018AB75E, 0x026A0A0E, 0x0000E014, 0x0379E898, + 0x01D01E5D, 0x01E738CC, 0x03715B7F, 0x00A406D9 }; + protected static final int[] C_d2 = new int[]{ 0x02B2F159, 0x01A6E509, 0x01156EBD, 0x00D4141D, 0x0001C029, 0x02F3D130, + 0x03A03CBB, 0x01CE7198, 0x02E2B6FF, 0x00480DB3 }; + protected static final int[] C_d4 = new int[]{ 0x0165E2B2, 0x034DCA13, 0x002ADD7A, 0x01A8283B, 0x00038052, 0x01E7A260, + 0x03407977, 0x019CE331, 0x01C56DFF, 0x00901B67 }; + + protected static final int WNAF_WIDTH_BASE = 7; + + protected static final int PRECOMP_BLOCKS = 8; + protected static final int PRECOMP_TEETH = 4; + protected static final int PRECOMP_SPACING = 8; + protected static final int PRECOMP_POINTS = 1 << (PRECOMP_TEETH - 1); + protected static final int PRECOMP_MASK = PRECOMP_POINTS - 1; + + protected static final Object precompLock = new Object(); + // TODO[ed25519] Convert to PointPrecomp + protected static PointExt[] precompBaseTable = null; + protected static int[] precompBase = null; + + protected static class PointAccum + { + int[] x = F.create(); + int[] y = F.create(); + int[] z = F.create(); + int[] u = F.create(); + int[] v = F.create(); + } + + protected static class PointAffine + { + int[] x = F.create(); + int[] y = F.create(); + } + + protected static class PointExt + { + int[] x = F.create(); + int[] y = F.create(); + int[] z = F.create(); + int[] t = F.create(); + } + + protected static class PointPrecomp + { + int[] ypx_h = F.create(); + int[] ymx_h = F.create(); + int[] xyd = F.create(); + } + + protected static byte[] calculateS(byte[] r, byte[] k, byte[] s) + { + int[] t = new int[SCALAR_INTS * 2]; decodeScalar(r, 0, t); + int[] u = new int[SCALAR_INTS]; decodeScalar(k, 0, u); + int[] v = new int[SCALAR_INTS]; decodeScalar(s, 0, v); + + Nat256.mulAddTo(u, v, t); + + byte[] result = new byte[SCALAR_BYTES * 2]; + for (int i = 0; i < t.length; ++i) + { + encode32(t[i], result, i * 4); + } + return reduceScalar(result); + } + + protected static boolean checkContextVar(byte[] ctx , byte phflag) + { + return ctx == null && phflag == 0x00 + || ctx != null && ctx.length < 256; + } + + protected static int checkPoint(int[] x, int[] y) + { + int[] t = F.create(); + int[] u = F.create(); + int[] v = F.create(); + + F.sqr(x, u); + F.sqr(y, v); + F.mul(u, v, t); + F.sub(v, u, v); + F.mul(t, C_d, t); + F.addOne(t); + F.sub(t, v, t); + F.normalize(t); + + return F.isZero(t); + } + + protected static int checkPoint(int[] x, int[] y, int[] z) + { + int[] t = F.create(); + int[] u = F.create(); + int[] v = F.create(); + int[] w = F.create(); + + F.sqr(x, u); + F.sqr(y, v); + F.sqr(z, w); + F.mul(u, v, t); + F.sub(v, u, v); + F.mul(v, w, v); + F.sqr(w, w); + F.mul(t, C_d, t); + F.add(t, w, t); + F.sub(t, v, t); + F.normalize(t); + + return F.isZero(t); + } + + protected static boolean checkPointVar(byte[] p) + { + int[] t = new int[8]; + decode32(p, 0, t, 0, 8); + t[7] &= 0x7FFFFFFF; + return !Nat256.gte(t, P); + } + + protected static boolean checkScalarVar(byte[] s) + { + int[] n = new int[SCALAR_INTS]; + decodeScalar(s, 0, n); + return !Nat256.gte(n, L); + } + + protected static Digest createDigest() + { + return new SHA512Digest(); + } + + public static Digest createPrehash() + { + return createDigest(); + } + + protected static int decode24(byte[] bs, int off) + { + int n = bs[ off] & 0xFF; + n |= (bs[++off] & 0xFF) << 8; + n |= (bs[++off] & 0xFF) << 16; + return n; + } + + protected static int decode32(byte[] bs, int off) + { + int n = bs[off] & 0xFF; + n |= (bs[++off] & 0xFF) << 8; + n |= (bs[++off] & 0xFF) << 16; + n |= bs[++off] << 24; + return n; + } + + protected static void decode32(byte[] bs, int bsOff, int[] n, int nOff, int nLen) + { + for (int i = 0; i < nLen; ++i) + { + n[nOff + i] = decode32(bs, bsOff + i * 4); + } + } + + protected static boolean decodePointVar(byte[] p, int pOff, boolean negate, PointAffine r) + { + byte[] py = Arrays.copyOfRange(p, pOff, pOff + POINT_BYTES); + if (!checkPointVar(py)) + { + return false; + } + + int x_0 = (py[POINT_BYTES - 1] & 0x80) >>> 7; + py[POINT_BYTES - 1] &= 0x7F; + + F.decode(py, 0, r.y); + + int[] u = F.create(); + int[] v = F.create(); + + F.sqr(r.y, u); + F.mul(C_d, u, v); + F.subOne(u); + F.addOne(v); + + if (!F.sqrtRatioVar(u, v, r.x)) + { + return false; + } + + F.normalize(r.x); + if (x_0 == 1 && F.isZeroVar(r.x)) + { + return false; + } + + if (negate ^ (x_0 != (r.x[0] & 1))) + { + F.negate(r.x, r.x); + } + + return true; + } + + protected static void decodeScalar(byte[] k, int kOff, int[] n) + { + decode32(k, kOff, n, 0, SCALAR_INTS); + } + + protected static void dom2(Digest d, byte phflag, byte[] ctx) + { + if (ctx != null) + { + int n = DOM2_PREFIX.length; + byte[] t = new byte[n + 2 + ctx.length]; + System.arraycopy(DOM2_PREFIX, 0, t, 0, n); + t[n] = phflag; + t[n + 1] = (byte)ctx.length; + System.arraycopy(ctx, 0, t, n + 2, ctx.length); + + d.update(t, 0, t.length); + } + } + + protected static void encode24(int n, byte[] bs, int off) + { + bs[ off] = (byte)(n ); + bs[++off] = (byte)(n >>> 8); + bs[++off] = (byte)(n >>> 16); + } + + protected static void encode32(int n, byte[] bs, int off) + { + bs[ off] = (byte)(n ); + bs[++off] = (byte)(n >>> 8); + bs[++off] = (byte)(n >>> 16); + bs[++off] = (byte)(n >>> 24); + } + + protected static void encode56(long n, byte[] bs, int off) + { + encode32((int)n, bs, off); + encode24((int)(n >>> 32), bs, off + 4); + } + + protected static int encodePoint(PointAccum p, byte[] r, int rOff) + { + int[] x = F.create(); + int[] y = F.create(); + + F.inv(p.z, y); + F.mul(p.x, y, x); + F.mul(p.y, y, y); + F.normalize(x); + F.normalize(y); + + int result = checkPoint(x, y); + + F.encode(y, r, rOff); + r[rOff + POINT_BYTES - 1] |= ((x[0] & 1) << 7); + + return result; + } + + public static void generatePrivateKey(SecureRandom random, byte[] k) + { + random.nextBytes(k); + } + + public static void generatePublicKey(byte[] sk, int skOff, byte[] pk, int pkOff) + { + Digest d = createDigest(); + byte[] h = new byte[d.getDigestSize()]; + + d.update(sk, skOff, SECRET_KEY_SIZE); + d.doFinal(h, 0); + + byte[] s = new byte[SCALAR_BYTES]; + pruneScalar(h, 0, s); + + scalarMultBaseEncoded(s, pk, pkOff); + } + + protected static int getWindow4(int[] x, int n) + { + int w = n >>> 3, b = (n & 7) << 2; + return (x[w] >>> b) & 15; + } + + protected static byte[] getWnafVar(int[] n, int width) + { +// assert n[SCALAR_INTS - 1] >>> 28 == 0; + + int[] t = new int[SCALAR_INTS * 2]; + { + int tPos = t.length, c = 0; + int i = SCALAR_INTS; + while (--i >= 0) + { + int next = n[i]; + t[--tPos] = (next >>> 16) | (c << 16); + t[--tPos] = c = next; + } + } + + byte[] ws = new byte[253]; + + final int pow2 = 1 << width; + final int mask = pow2 - 1; + final int sign = pow2 >>> 1; + + int j = 0, carry = 0; + for (int i = 0; i < t.length; ++i, j -= 16) + { + int word = t[i]; + while (j < 16) + { + int word16 = word >>> j; + int bit = word16 & 1; + + if (bit == carry) + { + ++j; + continue; + } + + int digit = (word16 & mask) + carry; + carry = digit & sign; + digit -= (carry << 1); + carry >>>= (width - 1); + + ws[(i << 4) + j] = (byte)digit; + + j += width; + } + } + +// assert carry == 0; + + return ws; + } + + protected static void implSign(Digest d, byte[] h, byte[] s, byte[] pk, int pkOff, byte[] ctx, byte phflag, byte[] m, + int mOff, int mLen, byte[] sig, int sigOff) + { + dom2(d, phflag, ctx); + d.update(h, SCALAR_BYTES, SCALAR_BYTES); + d.update(m, mOff, mLen); + d.doFinal(h, 0); + + byte[] r = reduceScalar(h); + byte[] R = new byte[POINT_BYTES]; + scalarMultBaseEncoded(r, R, 0); + + dom2(d, phflag, ctx); + d.update(R, 0, POINT_BYTES); + d.update(pk, pkOff, POINT_BYTES); + d.update(m, mOff, mLen); + d.doFinal(h, 0); + + byte[] k = reduceScalar(h); + byte[] S = calculateS(r, k, s); + + System.arraycopy(R, 0, sig, sigOff, POINT_BYTES); + System.arraycopy(S, 0, sig, sigOff + POINT_BYTES, SCALAR_BYTES); + } + + protected static void implSign(byte[] sk, int skOff, byte[] ctx, byte phflag, byte[] m, int mOff, int mLen, + byte[] sig, int sigOff) + { + if (!checkContextVar(ctx, phflag)) + { + throw new IllegalArgumentException("ctx"); + } + + Digest d = createDigest(); + byte[] h = new byte[d.getDigestSize()]; + + d.update(sk, skOff, SECRET_KEY_SIZE); + d.doFinal(h, 0); + + byte[] s = new byte[SCALAR_BYTES]; + pruneScalar(h, 0, s); + + byte[] pk = new byte[POINT_BYTES]; + scalarMultBaseEncoded(s, pk, 0); + + implSign(d, h, s, pk, 0, ctx, phflag, m, mOff, mLen, sig, sigOff); + } + + protected static void implSign(byte[] sk, int skOff, byte[] pk, int pkOff, byte[] ctx, byte phflag, + byte[] m, int mOff, int mLen, byte[] sig, int sigOff) + { + if (!checkContextVar(ctx, phflag)) + { + throw new IllegalArgumentException("ctx"); + } + + Digest d = createDigest(); + byte[] h = new byte[d.getDigestSize()]; + + d.update(sk, skOff, SECRET_KEY_SIZE); + d.doFinal(h, 0); + + byte[] s = new byte[SCALAR_BYTES]; + pruneScalar(h, 0, s); + + implSign(d, h, s, pk, pkOff, ctx, phflag, m, mOff, mLen, sig, sigOff); + } + + protected static boolean implVerify(byte[] sig, int sigOff, byte[] pk, int pkOff, byte[] ctx, byte phflag, byte[] m, + int mOff, int mLen) + { + if (!checkContextVar(ctx, phflag)) + { + throw new IllegalArgumentException("ctx"); + } + + byte[] R = Arrays.copyOfRange(sig, sigOff, sigOff + POINT_BYTES); + byte[] S = Arrays.copyOfRange(sig, sigOff + POINT_BYTES, sigOff + SIGNATURE_SIZE); + + if (!checkPointVar(R)) + { + return false; + } + if (!checkScalarVar(S)) + { + return false; + } + + PointAffine pA = new PointAffine(); + if (!decodePointVar(pk, pkOff, true, pA)) + { + return false; + } + + Digest d = createDigest(); + byte[] h = new byte[d.getDigestSize()]; + + dom2(d, phflag, ctx); + d.update(R, 0, POINT_BYTES); + d.update(pk, pkOff, POINT_BYTES); + d.update(m, mOff, mLen); + d.doFinal(h, 0); + + byte[] k = reduceScalar(h); + + int[] nS = new int[SCALAR_INTS]; + decodeScalar(S, 0, nS); + + int[] nA = new int[SCALAR_INTS]; + decodeScalar(k, 0, nA); + + PointAccum pR = new PointAccum(); + scalarMultStrausVar(nS, nA, pA, pR); + + byte[] check = new byte[POINT_BYTES]; + return 0 != encodePoint(pR, check, 0) && Arrays.areEqual(check, R); + } + + protected static void pointAdd(PointExt p, PointAccum r) + { + int[] a = F.create(); + int[] b = F.create(); + int[] c = F.create(); + int[] d = F.create(); + int[] e = r.u; + int[] f = F.create(); + int[] g = F.create(); + int[] h = r.v; + + F.apm(r.y, r.x, b, a); + F.apm(p.y, p.x, d, c); + F.mul(a, c, a); + F.mul(b, d, b); + F.mul(r.u, r.v, c); + F.mul(c, p.t, c); + F.mul(c, C_d2, c); + F.mul(r.z, p.z, d); + F.add(d, d, d); + F.apm(b, a, h, e); + F.apm(d, c, g, f); + F.carry(g); + F.mul(e, f, r.x); + F.mul(g, h, r.y); + F.mul(f, g, r.z); + } + + protected static void pointAdd(PointExt p, PointExt r) + { + int[] a = F.create(); + int[] b = F.create(); + int[] c = F.create(); + int[] d = F.create(); + int[] e = F.create(); + int[] f = F.create(); + int[] g = F.create(); + int[] h = F.create(); + + F.apm(p.y, p.x, b, a); + F.apm(r.y, r.x, d, c); + F.mul(a, c, a); + F.mul(b, d, b); + F.mul(p.t, r.t, c); + F.mul(c, C_d2, c); + F.mul(p.z, r.z, d); + F.add(d, d, d); + F.apm(b, a, h, e); + F.apm(d, c, g, f); + F.carry(g); + F.mul(e, f, r.x); + F.mul(g, h, r.y); + F.mul(f, g, r.z); + F.mul(e, h, r.t); + } + + protected static void pointAddVar(boolean negate, PointExt p, PointAccum r) + { + int[] a = F.create(); + int[] b = F.create(); + int[] c = F.create(); + int[] d = F.create(); + int[] e = r.u; + int[] f = F.create(); + int[] g = F.create(); + int[] h = r.v; + + int[] nc, nd, nf, ng; + if (negate) + { + nc = d; nd = c; nf = g; ng = f; + } + else + { + nc = c; nd = d; nf = f; ng = g; + } + + F.apm(r.y, r.x, b, a); + F.apm(p.y, p.x, nd, nc); + F.mul(a, c, a); + F.mul(b, d, b); + F.mul(r.u, r.v, c); + F.mul(c, p.t, c); + F.mul(c, C_d2, c); + F.mul(r.z, p.z, d); + F.add(d, d, d); + F.apm(b, a, h, e); + F.apm(d, c, ng, nf); + F.carry(ng); + F.mul(e, f, r.x); + F.mul(g, h, r.y); + F.mul(f, g, r.z); + } + + protected static void pointAddVar(boolean negate, PointExt p, PointExt q, PointExt r) + { + int[] a = F.create(); + int[] b = F.create(); + int[] c = F.create(); + int[] d = F.create(); + int[] e = F.create(); + int[] f = F.create(); + int[] g = F.create(); + int[] h = F.create(); + + int[] nc, nd, nf, ng; + if (negate) + { + nc = d; nd = c; nf = g; ng = f; + } + else + { + nc = c; nd = d; nf = f; ng = g; + } + + F.apm(p.y, p.x, b, a); + F.apm(q.y, q.x, nd, nc); + F.mul(a, c, a); + F.mul(b, d, b); + F.mul(p.t, q.t, c); + F.mul(c, C_d2, c); + F.mul(p.z, q.z, d); + F.add(d, d, d); + F.apm(b, a, h, e); + F.apm(d, c, ng, nf); + F.carry(ng); + F.mul(e, f, r.x); + F.mul(g, h, r.y); + F.mul(f, g, r.z); + F.mul(e, h, r.t); + } + + protected static void pointAddPrecomp(PointPrecomp p, PointAccum r) + { + int[] a = F.create(); + int[] b = F.create(); + int[] c = F.create(); + int[] e = r.u; + int[] f = F.create(); + int[] g = F.create(); + int[] h = r.v; + + F.apm(r.y, r.x, b, a); + F.mul(a, p.ymx_h, a); + F.mul(b, p.ypx_h, b); + F.mul(r.u, r.v, c); + F.mul(c, p.xyd, c); + F.apm(b, a, h, e); + F.apm(r.z, c, g, f); + F.carry(g); + F.mul(e, f, r.x); + F.mul(g, h, r.y); + F.mul(f, g, r.z); + } + + protected static PointExt pointCopy(PointAccum p) + { + PointExt r = new PointExt(); + F.copy(p.x, 0, r.x, 0); + F.copy(p.y, 0, r.y, 0); + F.copy(p.z, 0, r.z, 0); + F.mul(p.u, p.v, r.t); + return r; + } + + protected static PointExt pointCopy(PointAffine p) + { + PointExt r = new PointExt(); + F.copy(p.x, 0, r.x, 0); + F.copy(p.y, 0, r.y, 0); + pointExtendXY(r); + return r; + } + + protected static PointExt pointCopy(PointExt p) + { + PointExt r = new PointExt(); + pointCopy(p, r); + return r; + } + + protected static void pointCopy(PointAffine p, PointAccum r) + { + F.copy(p.x, 0, r.x, 0); + F.copy(p.y, 0, r.y, 0); + pointExtendXY(r); + } + + protected static void pointCopy(PointExt p, PointExt r) + { + F.copy(p.x, 0, r.x, 0); + F.copy(p.y, 0, r.y, 0); + F.copy(p.z, 0, r.z, 0); + F.copy(p.t, 0, r.t, 0); + } + + protected static void pointDouble(PointAccum r) + { + int[] a = F.create(); + int[] b = F.create(); + int[] c = F.create(); + int[] e = r.u; + int[] f = F.create(); + int[] g = F.create(); + int[] h = r.v; + + F.sqr(r.x, a); + F.sqr(r.y, b); + F.sqr(r.z, c); + F.add(c, c, c); + F.apm(a, b, h, g); + F.add(r.x, r.y, e); + F.sqr(e, e); + F.sub(h, e, e); + F.add(c, g, f); + F.carry(f); + F.mul(e, f, r.x); + F.mul(g, h, r.y); + F.mul(f, g, r.z); + } + + protected static void pointExtendXY(PointAccum p) + { + F.one(p.z); + F.copy(p.x, 0, p.u, 0); + F.copy(p.y, 0, p.v, 0); + } + + protected static void pointExtendXY(PointExt p) + { + F.one(p.z); + F.mul(p.x, p.y, p.t); + } + + protected static void pointLookup(int block, int index, PointPrecomp p) + { +// assert 0 <= block && block < PRECOMP_BLOCKS; +// assert 0 <= index && index < PRECOMP_POINTS; + + int off = block * PRECOMP_POINTS * 3 * F.SIZE; + + for (int i = 0; i < PRECOMP_POINTS; ++i) + { + int cond = ((i ^ index) - 1) >> 31; + F.cmov(cond, precompBase, off, p.ypx_h, 0); off += F.SIZE; + F.cmov(cond, precompBase, off, p.ymx_h, 0); off += F.SIZE; + F.cmov(cond, precompBase, off, p.xyd, 0); off += F.SIZE; + } + } + + protected static void pointLookup(int[] x, int n, int[] table, PointExt r) + { + // TODO This method is currently hardcoded to 4-bit windows and 8 precomputed points + + int w = getWindow4(x, n); + + int sign = (w >>> (4 - 1)) ^ 1; + int abs = (w ^ -sign) & 7; + +// assert sign == 0 || sign == 1; +// assert 0 <= abs && abs < 8; + + for (int i = 0, off = 0; i < 8; ++i) + { + int cond = ((i ^ abs) - 1) >> 31; + F.cmov(cond, table, off, r.x, 0); off += F.SIZE; + F.cmov(cond, table, off, r.y, 0); off += F.SIZE; + F.cmov(cond, table, off, r.z, 0); off += F.SIZE; + F.cmov(cond, table, off, r.t, 0); off += F.SIZE; + } + + F.cnegate(sign, r.x); + F.cnegate(sign, r.t); + } + + protected static void pointLookup(int[] table, int index, PointExt r) + { + int off = F.SIZE * 4 * index; + + F.copy(table, off, r.x, 0); off += F.SIZE; + F.copy(table, off, r.y, 0); off += F.SIZE; + F.copy(table, off, r.z, 0); off += F.SIZE; + F.copy(table, off, r.t, 0); + } + + protected static int[] pointPrecompute(PointAffine p, int count) + { +// assert count > 0; + + PointExt q = pointCopy(p); + PointExt d = pointCopy(q); + pointAdd(q, d); + + int[] table = F.createTable(count * 4); + int off = 0; + + int i = 0; + for (;;) + { + F.copy(q.x, 0, table, off); off += F.SIZE; + F.copy(q.y, 0, table, off); off += F.SIZE; + F.copy(q.z, 0, table, off); off += F.SIZE; + F.copy(q.t, 0, table, off); off += F.SIZE; + + if (++i == count) + { + break; + } + + pointAdd(d, q); + } + + return table; + } + + protected static PointExt[] pointPrecomputeVar(PointExt p, int count) + { +// assert count > 0; + + PointExt d = new PointExt(); + pointAddVar(false, p, p, d); + + PointExt[] table = new PointExt[count]; + table[0] = pointCopy(p); + for (int i = 1; i < count; ++i) + { + pointAddVar(false, table[i - 1], d, table[i] = new PointExt()); + } + return table; + } + + protected static void pointSetNeutral(PointAccum p) + { + F.zero(p.x); + F.one(p.y); + F.one(p.z); + F.zero(p.u); + F.one(p.v); + } + + protected static void pointSetNeutral(PointExt p) + { + F.zero(p.x); + F.one(p.y); + F.one(p.z); + F.zero(p.t); + } + + public static void precompute() + { + synchronized (precompLock) + { + if (precompBase != null) + { + return; + } + + // Precomputed table for the base point in verification ladder + { + PointExt b = new PointExt(); + F.copy(B_x, 0, b.x, 0); + F.copy(B_y, 0, b.y, 0); + pointExtendXY(b); + + precompBaseTable = pointPrecomputeVar(b, 1 << (WNAF_WIDTH_BASE - 2)); + } + + PointAccum p = new PointAccum(); + F.copy(B_x, 0, p.x, 0); + F.copy(B_y, 0, p.y, 0); + pointExtendXY(p); + + precompBase = F.createTable(PRECOMP_BLOCKS * PRECOMP_POINTS * 3); + + int off = 0; + for (int b = 0; b < PRECOMP_BLOCKS; ++b) + { + PointExt[] ds = new PointExt[PRECOMP_TEETH]; + + PointExt sum = new PointExt(); + pointSetNeutral(sum); + + for (int t = 0; t < PRECOMP_TEETH; ++t) + { + PointExt q = pointCopy(p); + pointAddVar(true, sum, q, sum); + pointDouble(p); + + ds[t] = pointCopy(p); + + if (b + t != PRECOMP_BLOCKS + PRECOMP_TEETH - 2) + { + for (int s = 1; s < PRECOMP_SPACING; ++s) + { + pointDouble(p); + } + } + } + + PointExt[] points = new PointExt[PRECOMP_POINTS]; + int k = 0; + points[k++] = sum; + + for (int t = 0; t < (PRECOMP_TEETH - 1); ++t) + { + int size = 1 << t; + for (int j = 0; j < size; ++j, ++k) + { + pointAddVar(false, points[k - size], ds[t], points[k] = new PointExt()); + } + } + +// assert k == PRECOMP_POINTS; + + int[] cs = F.createTable(PRECOMP_POINTS); + + // TODO[ed25519] A single batch inversion across all blocks? + { + int[] u = F.create(); + F.copy(points[0].z, 0, u, 0); + F.copy(u, 0, cs, 0); + + int i = 0; + while (++i < PRECOMP_POINTS) + { + F.mul(u, points[i].z, u); + F.copy(u, 0, cs, i * F.SIZE); + } + + F.add(u, u, u); + F.invVar(u, u); + --i; + + int[] t = F.create(); + + while (i > 0) + { + int j = i--; + F.copy(cs, i * F.SIZE, t, 0); + F.mul(t, u, t); + F.copy(t, 0, cs, j * F.SIZE); + F.mul(u, points[j].z, u); + } + + F.copy(u, 0, cs, 0); + } + + for (int i = 0; i < PRECOMP_POINTS; ++i) + { + PointExt q = points[i]; + + int[] x = F.create(); + int[] y = F.create(); + +// F.add(q.z, q.z, x); +// F.invVar(x, y); + F.copy(cs, i * F.SIZE, y, 0); + + F.mul(q.x, y, x); + F.mul(q.y, y, y); + + PointPrecomp r = new PointPrecomp(); + F.apm(y, x, r.ypx_h, r.ymx_h); + F.mul(x, y, r.xyd); + F.mul(r.xyd, C_d4, r.xyd); + + F.normalize(r.ypx_h); + F.normalize(r.ymx_h); +// F.normalize(r.xyd); + + F.copy(r.ypx_h, 0, precompBase, off); off += F.SIZE; + F.copy(r.ymx_h, 0, precompBase, off); off += F.SIZE; + F.copy(r.xyd, 0, precompBase, off); off += F.SIZE; + } + } + +// assert off == precompBase.length; + } + } + + protected static void pruneScalar(byte[] n, int nOff, byte[] r) + { + System.arraycopy(n, nOff, r, 0, SCALAR_BYTES); + + r[0] &= 0xF8; + r[SCALAR_BYTES - 1] &= 0x7F; + r[SCALAR_BYTES - 1] |= 0x40; + } + + protected static byte[] reduceScalar(byte[] n) + { + long x00 = decode32(n, 0) & M32L; // x00:32/-- + long x01 = (decode24(n, 4) << 4) & M32L; // x01:28/-- + long x02 = decode32(n, 7) & M32L; // x02:32/-- + long x03 = (decode24(n, 11) << 4) & M32L; // x03:28/-- + long x04 = decode32(n, 14) & M32L; // x04:32/-- + long x05 = (decode24(n, 18) << 4) & M32L; // x05:28/-- + long x06 = decode32(n, 21) & M32L; // x06:32/-- + long x07 = (decode24(n, 25) << 4) & M32L; // x07:28/-- + long x08 = decode32(n, 28) & M32L; // x08:32/-- + long x09 = (decode24(n, 32) << 4) & M32L; // x09:28/-- + long x10 = decode32(n, 35) & M32L; // x10:32/-- + long x11 = (decode24(n, 39) << 4) & M32L; // x11:28/-- + long x12 = decode32(n, 42) & M32L; // x12:32/-- + long x13 = (decode24(n, 46) << 4) & M32L; // x13:28/-- + long x14 = decode32(n, 49) & M32L; // x14:32/-- + long x15 = (decode24(n, 53) << 4) & M32L; // x15:28/-- + long x16 = decode32(n, 56) & M32L; // x16:32/-- + long x17 = (decode24(n, 60) << 4) & M32L; // x17:28/-- + long x18 = n[63] & M08L; // x18:08/-- + long t; + +// x18 += (x17 >> 28); x17 &= M28L; + x09 -= x18 * L0; // x09:34/28 + x10 -= x18 * L1; // x10:33/30 + x11 -= x18 * L2; // x11:35/28 + x12 -= x18 * L3; // x12:32/31 + x13 -= x18 * L4; // x13:28/21 + + x17 += (x16 >> 28); x16 &= M28L; // x17:28/--, x16:28/-- + x08 -= x17 * L0; // x08:54/32 + x09 -= x17 * L1; // x09:52/51 + x10 -= x17 * L2; // x10:55/34 + x11 -= x17 * L3; // x11:51/36 + x12 -= x17 * L4; // x12:41/-- + +// x16 += (x15 >> 28); x15 &= M28L; + x07 -= x16 * L0; // x07:54/28 + x08 -= x16 * L1; // x08:54/53 + x09 -= x16 * L2; // x09:55/53 + x10 -= x16 * L3; // x10:55/52 + x11 -= x16 * L4; // x11:51/41 + + x15 += (x14 >> 28); x14 &= M28L; // x15:28/--, x14:28/-- + x06 -= x15 * L0; // x06:54/32 + x07 -= x15 * L1; // x07:54/53 + x08 -= x15 * L2; // x08:56/-- + x09 -= x15 * L3; // x09:55/54 + x10 -= x15 * L4; // x10:55/53 + +// x14 += (x13 >> 28); x13 &= M28L; + x05 -= x14 * L0; // x05:54/28 + x06 -= x14 * L1; // x06:54/53 + x07 -= x14 * L2; // x07:56/-- + x08 -= x14 * L3; // x08:56/51 + x09 -= x14 * L4; // x09:56/-- + + x13 += (x12 >> 28); x12 &= M28L; // x13:28/22, x12:28/-- + x04 -= x13 * L0; // x04:54/49 + x05 -= x13 * L1; // x05:54/53 + x06 -= x13 * L2; // x06:56/-- + x07 -= x13 * L3; // x07:56/52 + x08 -= x13 * L4; // x08:56/52 + + x12 += (x11 >> 28); x11 &= M28L; // x12:28/24, x11:28/-- + x03 -= x12 * L0; // x03:54/49 + x04 -= x12 * L1; // x04:54/51 + x05 -= x12 * L2; // x05:56/-- + x06 -= x12 * L3; // x06:56/52 + x07 -= x12 * L4; // x07:56/53 + + x11 += (x10 >> 28); x10 &= M28L; // x11:29/--, x10:28/-- + x02 -= x11 * L0; // x02:55/32 + x03 -= x11 * L1; // x03:55/-- + x04 -= x11 * L2; // x04:56/55 + x05 -= x11 * L3; // x05:56/52 + x06 -= x11 * L4; // x06:56/53 + + x10 += (x09 >> 28); x09 &= M28L; // x10:29/--, x09:28/-- + x01 -= x10 * L0; // x01:55/28 + x02 -= x10 * L1; // x02:55/54 + x03 -= x10 * L2; // x03:56/55 + x04 -= x10 * L3; // x04:57/-- + x05 -= x10 * L4; // x05:56/53 + + x08 += (x07 >> 28); x07 &= M28L; // x08:56/53, x07:28/-- + x09 += (x08 >> 28); x08 &= M28L; // x09:29/25, x08:28/-- + + t = x08 >>> 27; + x09 += t; // x09:29/26 + + x00 -= x09 * L0; // x00:55/53 + x01 -= x09 * L1; // x01:55/54 + x02 -= x09 * L2; // x02:57/-- + x03 -= x09 * L3; // x03:57/-- + x04 -= x09 * L4; // x04:57/42 + + x01 += (x00 >> 28); x00 &= M28L; + x02 += (x01 >> 28); x01 &= M28L; + x03 += (x02 >> 28); x02 &= M28L; + x04 += (x03 >> 28); x03 &= M28L; + x05 += (x04 >> 28); x04 &= M28L; + x06 += (x05 >> 28); x05 &= M28L; + x07 += (x06 >> 28); x06 &= M28L; + x08 += (x07 >> 28); x07 &= M28L; + x09 = (x08 >> 28); x08 &= M28L; + + x09 -= t; + +// assert x09 == 0L || x09 == -1L; + + x00 += x09 & L0; + x01 += x09 & L1; + x02 += x09 & L2; + x03 += x09 & L3; + x04 += x09 & L4; + + x01 += (x00 >> 28); x00 &= M28L; + x02 += (x01 >> 28); x01 &= M28L; + x03 += (x02 >> 28); x02 &= M28L; + x04 += (x03 >> 28); x03 &= M28L; + x05 += (x04 >> 28); x04 &= M28L; + x06 += (x05 >> 28); x05 &= M28L; + x07 += (x06 >> 28); x06 &= M28L; + x08 += (x07 >> 28); x07 &= M28L; + + byte[] r = new byte[SCALAR_BYTES]; + encode56(x00 | (x01 << 28), r, 0); + encode56(x02 | (x03 << 28), r, 7); + encode56(x04 | (x05 << 28), r, 14); + encode56(x06 | (x07 << 28), r, 21); + encode32((int)x08, r, 28); + return r; + } + + protected static void scalarMult(byte[] k, PointAffine p, PointAccum r) + { + int[] n = new int[SCALAR_INTS]; + decodeScalar(k, 0, n); + +// assert 0 == (n[0] & 7); +// assert 1 == n[SCALAR_INTS - 1] >>> 30; + + Nat.shiftDownBits(SCALAR_INTS, n, 3, 1); + + // Recode the scalar into signed-digit form + { + //int c1 = + Nat.cadd(SCALAR_INTS, ~n[0] & 1, n, L, n); //assert c1 == 0; + //int c2 = + Nat.shiftDownBit(SCALAR_INTS, n, 0); //assert c2 == (1 << 31); + } + +// assert 1 == n[SCALAR_INTS - 1] >>> 28; + + int[] table = pointPrecompute(p, 8); + PointExt q = new PointExt(); + + // Replace first 4 doublings (2^4 * P) with 1 addition (P + 15 * P) + pointCopy(p, r); + pointLookup(table, 7, q); + pointAdd(q, r); + + int w = 62; + for (;;) + { + pointLookup(n, w, table, q); + pointAdd(q, r); + + pointDouble(r); + pointDouble(r); + pointDouble(r); + + if (--w < 0) + { + break; + } + + pointDouble(r); + } + } + + protected static void scalarMultBase(byte[] k, PointAccum r) + { + precompute(); + + int[] n = new int[SCALAR_INTS]; + decodeScalar(k, 0, n); + + // Recode the scalar into signed-digit form, then group comb bits in each block + { + //int c1 = + Nat.cadd(SCALAR_INTS, ~n[0] & 1, n, L, n); //assert c1 == 0; + //int c2 = + Nat.shiftDownBit(SCALAR_INTS, n, 1); //assert c2 == (1 << 31); + + for (int i = 0; i < SCALAR_INTS; ++i) + { + n[i] = Interleave.shuffle2(n[i]); + } + } + + PointPrecomp p = new PointPrecomp(); + + pointSetNeutral(r); + + int cOff = (PRECOMP_SPACING - 1) * PRECOMP_TEETH; + for (;;) + { + for (int b = 0; b < PRECOMP_BLOCKS; ++b) + { + int w = n[b] >>> cOff; + int sign = (w >>> (PRECOMP_TEETH - 1)) & 1; + int abs = (w ^ -sign) & PRECOMP_MASK; + +// assert sign == 0 || sign == 1; +// assert 0 <= abs && abs < PRECOMP_POINTS; + + pointLookup(b, abs, p); + + F.cswap(sign, p.ypx_h, p.ymx_h); + F.cnegate(sign, p.xyd); + + pointAddPrecomp(p, r); + } + + if ((cOff -= PRECOMP_TEETH) < 0) + { + break; + } + + pointDouble(r); + } + } + + protected static void scalarMultBaseEncoded(byte[] k, byte[] r, int rOff) + { + PointAccum p = new PointAccum(); + scalarMultBase(k, p); + if (0 == encodePoint(p, r, rOff)) + { + throw new IllegalStateException(); + } + } + + /** + * NOTE: Only for use by X25519 + */ + public static void scalarMultBaseYZ(X25519.Friend friend, byte[] k, int kOff, int[] y, int[] z) + { + if (null == friend) + { + throw new NullPointerException("This method is only for use by X25519"); + } + + byte[] n = new byte[SCALAR_BYTES]; + pruneScalar(k, kOff, n); + + PointAccum p = new PointAccum(); + scalarMultBase(n, p); + if (0 == checkPoint(p.x, p.y, p.z)) + { + throw new IllegalStateException(); + } + F.copy(p.y, 0, y, 0); + F.copy(p.z, 0, z, 0); + } + + protected static void scalarMultStrausVar(int[] nb, int[] np, PointAffine p, PointAccum r) + { + precompute(); + + final int width = 5; + + byte[] ws_b = getWnafVar(nb, WNAF_WIDTH_BASE); + byte[] ws_p = getWnafVar(np, width); + + PointExt[] tp = pointPrecomputeVar(pointCopy(p), 1 << (width - 2)); + + pointSetNeutral(r); + + for (int bit = 252;;) + { + int wb = ws_b[bit]; + if (wb != 0) + { + int sign = wb >> 31; + int index = (wb ^ sign) >>> 1; + + pointAddVar((sign != 0), precompBaseTable[index], r); + } + + int wp = ws_p[bit]; + if (wp != 0) + { + int sign = wp >> 31; + int index = (wp ^ sign) >>> 1; + + pointAddVar((sign != 0), tp[index], r); + } + + if (--bit < 0) + { + break; + } + + pointDouble(r); + } + } + + public static void sign(byte[] sk, int skOff, byte[] m, int mOff, int mLen, byte[] sig, int sigOff) + { + byte[] ctx = null; + byte phflag = 0x00; + + implSign(sk, skOff, ctx, phflag, m, mOff, mLen, sig, sigOff); + } + + public static void sign(byte[] sk, int skOff, byte[] pk, int pkOff, byte[] m, int mOff, int mLen, byte[] sig, int sigOff) + { + byte[] ctx = null; + byte phflag = 0x00; + + implSign(sk, skOff, pk, pkOff, ctx, phflag, m, mOff, mLen, sig, sigOff); + } + + public static void sign(byte[] sk, int skOff, byte[] ctx, byte[] m, int mOff, int mLen, byte[] sig, int sigOff) + { + byte phflag = 0x00; + + implSign(sk, skOff, ctx, phflag, m, mOff, mLen, sig, sigOff); + } + + public static void sign(byte[] sk, int skOff, byte[] pk, int pkOff, byte[] ctx, byte[] m, int mOff, int mLen, byte[] sig, int sigOff) + { + byte phflag = 0x00; + + implSign(sk, skOff, pk, pkOff, ctx, phflag, m, mOff, mLen, sig, sigOff); + } + + public static void signPrehash(byte[] sk, int skOff, byte[] ctx, byte[] ph, int phOff, byte[] sig, int sigOff) + { + byte phflag = 0x01; + + implSign(sk, skOff, ctx, phflag, ph, phOff, PREHASH_SIZE, sig, sigOff); + } + + public static void signPrehash(byte[] sk, int skOff, byte[] pk, int pkOff, byte[] ctx, byte[] ph, int phOff, byte[] sig, int sigOff) + { + byte phflag = 0x01; + + implSign(sk, skOff, pk, pkOff, ctx, phflag, ph, phOff, PREHASH_SIZE, sig, sigOff); + } + + public static void signPrehash(byte[] sk, int skOff, byte[] ctx, Digest ph, byte[] sig, int sigOff) + { + byte[] m = new byte[PREHASH_SIZE]; + if (PREHASH_SIZE != ph.doFinal(m, 0)) + { + throw new IllegalArgumentException("ph"); + } + + byte phflag = 0x01; + + implSign(sk, skOff, ctx, phflag, m, 0, m.length, sig, sigOff); + } + + public static void signPrehash(byte[] sk, int skOff, byte[] pk, int pkOff, byte[] ctx, Digest ph, byte[] sig, int sigOff) + { + byte[] m = new byte[PREHASH_SIZE]; + if (PREHASH_SIZE != ph.doFinal(m, 0)) + { + throw new IllegalArgumentException("ph"); + } + + byte phflag = 0x01; + + implSign(sk, skOff, pk, pkOff, ctx, phflag, m, 0, m.length, sig, sigOff); + } + + public static boolean verify(byte[] sig, int sigOff, byte[] pk, int pkOff, byte[] m, int mOff, int mLen) + { + byte[] ctx = null; + byte phflag = 0x00; + + return implVerify(sig, sigOff, pk, pkOff, ctx, phflag, m, mOff, mLen); + } + + public static boolean verify(byte[] sig, int sigOff, byte[] pk, int pkOff, byte[] ctx, byte[] m, int mOff, int mLen) + { + byte phflag = 0x00; + + return implVerify(sig, sigOff, pk, pkOff, ctx, phflag, m, mOff, mLen); + } + + public static boolean verifyPrehash(byte[] sig, int sigOff, byte[] pk, int pkOff, byte[] ctx, byte[] ph, int phOff) + { + byte phflag = 0x01; + + return implVerify(sig, sigOff, pk, pkOff, ctx, phflag, ph, phOff, PREHASH_SIZE); + } + + public static boolean verifyPrehash(byte[] sig, int sigOff, byte[] pk, int pkOff, byte[] ctx, Digest ph) + { + byte[] m = new byte[PREHASH_SIZE]; + if (PREHASH_SIZE != ph.doFinal(m, 0)) + { + throw new IllegalArgumentException("ph"); + } + + byte phflag = 0x01; + + return implVerify(sig, sigOff, pk, pkOff, ctx, phflag, m, 0, m.length); + } +} diff --git a/src/main/java/org/qortal/crypto/Crypto.java b/src/main/java/org/qortal/crypto/Crypto.java index 5d91781c..8ee0b2b2 100644 --- a/src/main/java/org/qortal/crypto/Crypto.java +++ b/src/main/java/org/qortal/crypto/Crypto.java @@ -270,10 +270,10 @@ public abstract class Crypto { } public static byte[] getSharedSecret(byte[] privateKey, byte[] publicKey) { - byte[] x25519PrivateKey = BouncyCastle25519.toX25519PrivateKey(privateKey); + byte[] x25519PrivateKey = Qortal25519Extras.toX25519PrivateKey(privateKey); X25519PrivateKeyParameters xPrivateKeyParams = new X25519PrivateKeyParameters(x25519PrivateKey, 0); - byte[] x25519PublicKey = BouncyCastle25519.toX25519PublicKey(publicKey); + byte[] x25519PublicKey = Qortal25519Extras.toX25519PublicKey(publicKey); X25519PublicKeyParameters xPublicKeyParams = new X25519PublicKeyParameters(x25519PublicKey, 0); byte[] sharedSecret = new byte[SHARED_SECRET_LENGTH]; diff --git a/src/main/java/org/qortal/crypto/Qortal25519Extras.java b/src/main/java/org/qortal/crypto/Qortal25519Extras.java new file mode 100644 index 00000000..42cca93e --- /dev/null +++ b/src/main/java/org/qortal/crypto/Qortal25519Extras.java @@ -0,0 +1,234 @@ +package org.qortal.crypto; + +import org.bouncycastle.crypto.Digest; +import org.bouncycastle.crypto.digests.SHA512Digest; +import org.bouncycastle.math.ec.rfc7748.X25519; +import org.bouncycastle.math.ec.rfc7748.X25519Field; +import org.bouncycastle.math.ec.rfc8032.Ed25519; +import org.bouncycastle.math.raw.Nat256; + +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.Collection; + +/** + * Additions to BouncyCastle providing: + *

+ * + */ +public abstract class Qortal25519Extras extends BouncyCastleEd25519 { + + private static final SecureRandom SECURE_RANDOM = new SecureRandom(); + + public static byte[] toX25519PublicKey(byte[] ed25519PublicKey) { + int[] one = new int[X25519Field.SIZE]; + X25519Field.one(one); + + PointAffine pA = new PointAffine(); + if (!decodePointVar(ed25519PublicKey, 0, true, pA)) + return null; + + int[] y = pA.y; + + int[] oneMinusY = new int[X25519Field.SIZE]; + X25519Field.sub(one, y, oneMinusY); + + int[] onePlusY = new int[X25519Field.SIZE]; + X25519Field.add(one, y, onePlusY); + + int[] oneMinusYInverted = new int[X25519Field.SIZE]; + X25519Field.inv(oneMinusY, oneMinusYInverted); + + int[] u = new int[X25519Field.SIZE]; + X25519Field.mul(onePlusY, oneMinusYInverted, u); + + X25519Field.normalize(u); + + byte[] x25519PublicKey = new byte[X25519.SCALAR_SIZE]; + X25519Field.encode(u, x25519PublicKey, 0); + + return x25519PublicKey; + } + + public static byte[] toX25519PrivateKey(byte[] ed25519PrivateKey) { + Digest d = Ed25519.createPrehash(); + byte[] h = new byte[d.getDigestSize()]; + + d.update(ed25519PrivateKey, 0, ed25519PrivateKey.length); + d.doFinal(h, 0); + + byte[] s = new byte[X25519.SCALAR_SIZE]; + + System.arraycopy(h, 0, s, 0, X25519.SCALAR_SIZE); + s[0] &= 0xF8; + s[X25519.SCALAR_SIZE - 1] &= 0x7F; + s[X25519.SCALAR_SIZE - 1] |= 0x40; + + return s; + } + + // Mostly for test support + public static PointAccum newPointAccum() { + return new PointAccum(); + } + + public static byte[] aggregatePublicKeys(Collection publicKeys) { + PointAccum rAccum = null; + + for (byte[] publicKey : publicKeys) { + PointAffine pA = new PointAffine(); + if (!decodePointVar(publicKey, 0, false, pA)) + // Failed to decode + return null; + + if (rAccum == null) { + rAccum = new PointAccum(); + pointCopy(pA, rAccum); + } else { + pointAdd(pointCopy(pA), rAccum); + } + } + + byte[] publicKey = new byte[SCALAR_BYTES]; + if (0 == encodePoint(rAccum, publicKey, 0)) + // Failed to encode + return null; + + return publicKey; + } + + public static byte[] aggregateSignatures(Collection signatures) { + // Signatures are (R, s) + // R is a point + // s is a scalar + PointAccum rAccum = null; + int[] sAccum = new int[SCALAR_INTS]; + + byte[] rEncoded = new byte[POINT_BYTES]; + int[] sPart = new int[SCALAR_INTS]; + for (byte[] signature : signatures) { + System.arraycopy(signature,0, rEncoded, 0, rEncoded.length); + + PointAffine pA = new PointAffine(); + if (!decodePointVar(rEncoded, 0, false, pA)) + // Failed to decode + return null; + + if (rAccum == null) { + rAccum = new PointAccum(); + pointCopy(pA, rAccum); + + decode32(signature, rEncoded.length, sAccum, 0, SCALAR_INTS); + } else { + pointAdd(pointCopy(pA), rAccum); + + decode32(signature, rEncoded.length, sPart, 0, SCALAR_INTS); + Nat256.addTo(sPart, sAccum); + + // "mod L" on sAccum + if (Nat256.gte(sAccum, L)) + Nat256.subFrom(L, sAccum); + } + } + + byte[] signature = new byte[SIGNATURE_SIZE]; + if (0 == encodePoint(rAccum, signature, 0)) + // Failed to encode + return null; + + for (int i = 0; i < sAccum.length; ++i) { + encode32(sAccum[i], signature, POINT_BYTES + i * 4); + } + + return signature; + } + + public static byte[] signForAggregation(byte[] privateKey, byte[] message) { + // Very similar to BouncyCastle's implementation except we use secure random nonce and different hash + Digest d = new SHA512Digest(); + byte[] h = new byte[d.getDigestSize()]; + + d.reset(); + d.update(privateKey, 0, privateKey.length); + d.doFinal(h, 0); + + byte[] sH = new byte[SCALAR_BYTES]; + pruneScalar(h, 0, sH); + + byte[] publicKey = new byte[SCALAR_BYTES]; + scalarMultBaseEncoded(sH, publicKey, 0); + + byte[] rSeed = new byte[d.getDigestSize()]; + SECURE_RANDOM.nextBytes(rSeed); + + byte[] r = new byte[SCALAR_BYTES]; + pruneScalar(rSeed, 0, r); + + byte[] R = new byte[POINT_BYTES]; + scalarMultBaseEncoded(r, R, 0); + + d.reset(); + d.update(message, 0, message.length); + d.doFinal(h, 0); + byte[] k = reduceScalar(h); + + byte[] s = calculateS(r, k, sH); + + byte[] signature = new byte[SIGNATURE_SIZE]; + System.arraycopy(R, 0, signature, 0, POINT_BYTES); + System.arraycopy(s, 0, signature, POINT_BYTES, SCALAR_BYTES); + + return signature; + } + + public static boolean verifyAggregated(byte[] publicKey, byte[] signature, byte[] message) { + byte[] R = Arrays.copyOfRange(signature, 0, POINT_BYTES); + + byte[] s = Arrays.copyOfRange(signature, POINT_BYTES, POINT_BYTES + SCALAR_BYTES); + + if (!checkPointVar(R)) + // R out of bounds + return false; + + if (!checkScalarVar(s)) + // s out of bounds + return false; + + byte[] S = new byte[POINT_BYTES]; + scalarMultBaseEncoded(s, S, 0); + + PointAffine pA = new PointAffine(); + if (!decodePointVar(publicKey, 0, true, pA)) + // Failed to decode + return false; + + Digest d = new SHA512Digest(); + byte[] h = new byte[d.getDigestSize()]; + + d.update(message, 0, message.length); + d.doFinal(h, 0); + + byte[] k = reduceScalar(h); + + int[] nS = new int[SCALAR_INTS]; + decodeScalar(s, 0, nS); + + int[] nA = new int[SCALAR_INTS]; + decodeScalar(k, 0, nA); + + /*PointAccum*/ + PointAccum pR = new PointAccum(); + scalarMultStrausVar(nS, nA, pA, pR); + + byte[] check = new byte[POINT_BYTES]; + if (0 == encodePoint(pR, check, 0)) + // Failed to encode + return false; + + return Arrays.equals(check, R); + } +} diff --git a/src/test/java/org/qortal/test/CryptoTests.java b/src/test/java/org/qortal/test/CryptoTests.java index 6a0133d2..2cc73182 100644 --- a/src/test/java/org/qortal/test/CryptoTests.java +++ b/src/test/java/org/qortal/test/CryptoTests.java @@ -4,7 +4,7 @@ import org.junit.Test; import org.qortal.account.PrivateKeyAccount; import org.qortal.block.BlockChain; import org.qortal.crypto.AES; -import org.qortal.crypto.BouncyCastle25519; +import org.qortal.crypto.Qortal25519Extras; import org.qortal.crypto.Crypto; import org.qortal.test.common.Common; import org.qortal.utils.Base58; @@ -123,14 +123,14 @@ public class CryptoTests extends Common { random.nextBytes(ed25519PrivateKey); PrivateKeyAccount account = new PrivateKeyAccount(null, ed25519PrivateKey); - byte[] x25519PrivateKey = BouncyCastle25519.toX25519PrivateKey(account.getPrivateKey()); + byte[] x25519PrivateKey = Qortal25519Extras.toX25519PrivateKey(account.getPrivateKey()); X25519PrivateKeyParameters x25519PrivateKeyParams = new X25519PrivateKeyParameters(x25519PrivateKey, 0); // Derive X25519 public key from X25519 private key byte[] x25519PublicKeyFromPrivate = x25519PrivateKeyParams.generatePublicKey().getEncoded(); // Derive X25519 public key from Ed25519 public key - byte[] x25519PublicKeyFromEd25519 = BouncyCastle25519.toX25519PublicKey(account.getPublicKey()); + byte[] x25519PublicKeyFromEd25519 = Qortal25519Extras.toX25519PublicKey(account.getPublicKey()); assertEquals(String.format("Public keys do not match, from private key %s", Base58.encode(ed25519PrivateKey)), Base58.encode(x25519PublicKeyFromPrivate), Base58.encode(x25519PublicKeyFromEd25519)); } @@ -162,10 +162,10 @@ public class CryptoTests extends Common { } private static byte[] calcBCSharedSecret(byte[] ed25519PrivateKey, byte[] ed25519PublicKey) { - byte[] x25519PrivateKey = BouncyCastle25519.toX25519PrivateKey(ed25519PrivateKey); + byte[] x25519PrivateKey = Qortal25519Extras.toX25519PrivateKey(ed25519PrivateKey); X25519PrivateKeyParameters privateKeyParams = new X25519PrivateKeyParameters(x25519PrivateKey, 0); - byte[] x25519PublicKey = BouncyCastle25519.toX25519PublicKey(ed25519PublicKey); + byte[] x25519PublicKey = Qortal25519Extras.toX25519PublicKey(ed25519PublicKey); X25519PublicKeyParameters publicKeyParams = new X25519PublicKeyParameters(x25519PublicKey, 0); byte[] sharedSecret = new byte[32]; @@ -186,10 +186,10 @@ public class CryptoTests extends Common { final String expectedTheirX25519PublicKey = "ANjnZLRSzW9B1aVamiYGKP3XtBooU9tGGDjUiibUfzp2"; final String expectedSharedSecret = "DTMZYG96x8XZuGzDvHFByVLsXedimqtjiXHhXPVe58Ap"; - byte[] ourX25519PrivateKey = BouncyCastle25519.toX25519PrivateKey(ourPrivateKey); + byte[] ourX25519PrivateKey = Qortal25519Extras.toX25519PrivateKey(ourPrivateKey); assertEquals("X25519 private key incorrect", expectedOurX25519PrivateKey, Base58.encode(ourX25519PrivateKey)); - byte[] theirX25519PublicKey = BouncyCastle25519.toX25519PublicKey(theirPublicKey); + byte[] theirX25519PublicKey = Qortal25519Extras.toX25519PublicKey(theirPublicKey); assertEquals("X25519 public key incorrect", expectedTheirX25519PublicKey, Base58.encode(theirX25519PublicKey)); byte[] sharedSecret = calcBCSharedSecret(ourPrivateKey, theirPublicKey); diff --git a/src/test/java/org/qortal/test/SchnorrTests.java b/src/test/java/org/qortal/test/SchnorrTests.java new file mode 100644 index 00000000..03c92d2f --- /dev/null +++ b/src/test/java/org/qortal/test/SchnorrTests.java @@ -0,0 +1,190 @@ +package org.qortal.test; + +import com.google.common.hash.HashCode; +import com.google.common.primitives.Bytes; +import com.google.common.primitives.Longs; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.bouncycastle.jsse.provider.BouncyCastleJsseProvider; +import org.junit.Test; +import org.qortal.crypto.Qortal25519Extras; +import org.qortal.data.network.OnlineAccountData; +import org.qortal.transform.Transformer; + +import java.math.BigInteger; +import java.security.SecureRandom; +import java.security.Security; +import java.util.*; +import java.util.stream.Collectors; + +import static org.junit.Assert.*; + +public class SchnorrTests extends Qortal25519Extras { + + static { + // This must go before any calls to LogManager/Logger + System.setProperty("java.util.logging.manager", "org.apache.logging.log4j.jul.LogManager"); + + Security.insertProviderAt(new BouncyCastleProvider(), 0); + Security.insertProviderAt(new BouncyCastleJsseProvider(), 1); + } + + private static final SecureRandom SECURE_RANDOM = new SecureRandom(); + + @Test + public void testConversion() { + // Scalar form + byte[] scalarA = HashCode.fromString("0100000000000000000000000000000000000000000000000000000000000000".toLowerCase()).asBytes(); + System.out.printf("a: %s%n", HashCode.fromBytes(scalarA)); + + byte[] pointA = HashCode.fromString("5866666666666666666666666666666666666666666666666666666666666666".toLowerCase()).asBytes(); + + BigInteger expectedY = new BigInteger("46316835694926478169428394003475163141307993866256225615783033603165251855960"); + + PointAccum pointAccum = Qortal25519Extras.newPointAccum(); + scalarMultBase(scalarA, pointAccum); + + byte[] encoded = new byte[POINT_BYTES]; + if (0 == encodePoint(pointAccum, encoded, 0)) + fail("Point encoding failed"); + + System.out.printf("aG: %s%n", HashCode.fromBytes(encoded)); + assertArrayEquals(pointA, encoded); + + byte[] yBytes = new byte[POINT_BYTES]; + System.arraycopy(encoded,0, yBytes, 0, encoded.length); + Bytes.reverse(yBytes); + + System.out.printf("yBytes: %s%n", HashCode.fromBytes(yBytes)); + BigInteger yBI = new BigInteger(yBytes); + + System.out.printf("aG y: %s%n", yBI); + assertEquals(expectedY, yBI); + } + + @Test + public void testAddition() { + /* + * 1G: b'5866666666666666666666666666666666666666666666666666666666666666' + * 2G: b'c9a3f86aae465f0e56513864510f3997561fa2c9e85ea21dc2292309f3cd6022' + * 3G: b'd4b4f5784868c3020403246717ec169ff79e26608ea126a1ab69ee77d1b16712' + */ + + // Scalar form + byte[] s1 = HashCode.fromString("0100000000000000000000000000000000000000000000000000000000000000".toLowerCase()).asBytes(); + byte[] s2 = HashCode.fromString("0200000000000000000000000000000000000000000000000000000000000000".toLowerCase()).asBytes(); + + // Point form + byte[] g1 = HashCode.fromString("5866666666666666666666666666666666666666666666666666666666666666".toLowerCase()).asBytes(); + byte[] g2 = HashCode.fromString("c9a3f86aae465f0e56513864510f3997561fa2c9e85ea21dc2292309f3cd6022".toLowerCase()).asBytes(); + byte[] g3 = HashCode.fromString("d4b4f5784868c3020403246717ec169ff79e26608ea126a1ab69ee77d1b16712".toLowerCase()).asBytes(); + + PointAccum p1 = Qortal25519Extras.newPointAccum(); + scalarMultBase(s1, p1); + + PointAccum p2 = Qortal25519Extras.newPointAccum(); + scalarMultBase(s2, p2); + + pointAdd(pointCopy(p1), p2); + + byte[] encoded = new byte[POINT_BYTES]; + if (0 == encodePoint(p2, encoded, 0)) + fail("Point encoding failed"); + + System.out.printf("sum: %s%n", HashCode.fromBytes(encoded)); + assertArrayEquals(g3, encoded); + } + + @Test + public void testSimpleSign() { + byte[] privateKey = HashCode.fromString("0100000000000000000000000000000000000000000000000000000000000000".toLowerCase()).asBytes(); + byte[] message = HashCode.fromString("01234567".toLowerCase()).asBytes(); + + byte[] signature = signForAggregation(privateKey, message); + System.out.printf("signature: %s%n", HashCode.fromBytes(signature)); + } + + @Test + public void testSimpleVerify() { + byte[] privateKey = HashCode.fromString("0100000000000000000000000000000000000000000000000000000000000000".toLowerCase()).asBytes(); + byte[] message = HashCode.fromString("01234567".toLowerCase()).asBytes(); + byte[] signature = HashCode.fromString("13e58e88f3df9e06637d2d5bbb814c028e3ba135494530b9d3b120bdb31168d62c70a37ae9cfba816fe6038ee1ce2fb521b95c4a91c7ff0bb1dd2e67733f2b0d".toLowerCase()).asBytes(); + + byte[] publicKey = new byte[Transformer.PUBLIC_KEY_LENGTH]; + Qortal25519Extras.generatePublicKey(privateKey, 0, publicKey, 0); + + assertTrue(verifyAggregated(publicKey, signature, message)); + } + + @Test + public void testSimpleSignAndVerify() { + byte[] privateKey = HashCode.fromString("0100000000000000000000000000000000000000000000000000000000000000".toLowerCase()).asBytes(); + byte[] message = HashCode.fromString("01234567".toLowerCase()).asBytes(); + + byte[] signature = signForAggregation(privateKey, message); + + byte[] publicKey = new byte[Transformer.PUBLIC_KEY_LENGTH]; + Qortal25519Extras.generatePublicKey(privateKey, 0, publicKey, 0); + + assertTrue(verifyAggregated(publicKey, signature, message)); + } + + @Test + public void testSimpleAggregate() { + List onlineAccounts = generateOnlineAccounts(1); + + byte[] aggregatePublicKey = aggregatePublicKeys(onlineAccounts.stream().map(OnlineAccountData::getPublicKey).collect(Collectors.toUnmodifiableList())); + System.out.printf("Aggregate public key: %s%n", HashCode.fromBytes(aggregatePublicKey)); + + byte[] aggregateSignature = aggregateSignatures(onlineAccounts.stream().map(OnlineAccountData::getSignature).collect(Collectors.toUnmodifiableList())); + System.out.printf("Aggregate signature: %s%n", HashCode.fromBytes(aggregateSignature)); + + OnlineAccountData onlineAccount = onlineAccounts.get(0); + + assertArrayEquals(String.format("expected: %s, actual: %s", HashCode.fromBytes(onlineAccount.getPublicKey()), HashCode.fromBytes(aggregatePublicKey)), onlineAccount.getPublicKey(), aggregatePublicKey); + assertArrayEquals(String.format("expected: %s, actual: %s", HashCode.fromBytes(onlineAccount.getSignature()), HashCode.fromBytes(aggregateSignature)), onlineAccount.getSignature(), aggregateSignature); + + // This is the crucial test: + long timestamp = onlineAccount.getTimestamp(); + byte[] timestampBytes = Longs.toByteArray(timestamp); + assertTrue(verifyAggregated(aggregatePublicKey, aggregateSignature, timestampBytes)); + } + + @Test + public void testMultipleAggregate() { + List onlineAccounts = generateOnlineAccounts(5000); + + byte[] aggregatePublicKey = aggregatePublicKeys(onlineAccounts.stream().map(OnlineAccountData::getPublicKey).collect(Collectors.toUnmodifiableList())); + System.out.printf("Aggregate public key: %s%n", HashCode.fromBytes(aggregatePublicKey)); + + byte[] aggregateSignature = aggregateSignatures(onlineAccounts.stream().map(OnlineAccountData::getSignature).collect(Collectors.toUnmodifiableList())); + System.out.printf("Aggregate signature: %s%n", HashCode.fromBytes(aggregateSignature)); + + OnlineAccountData onlineAccount = onlineAccounts.get(0); + + // This is the crucial test: + long timestamp = onlineAccount.getTimestamp(); + byte[] timestampBytes = Longs.toByteArray(timestamp); + assertTrue(verifyAggregated(aggregatePublicKey, aggregateSignature, timestampBytes)); + } + + private List generateOnlineAccounts(int numAccounts) { + List onlineAccounts = new ArrayList<>(); + + long timestamp = System.currentTimeMillis(); + byte[] timestampBytes = Longs.toByteArray(timestamp); + + for (int a = 0; a < numAccounts; ++a) { + byte[] privateKey = new byte[Transformer.PUBLIC_KEY_LENGTH]; + SECURE_RANDOM.nextBytes(privateKey); + + byte[] publicKey = new byte[Transformer.PUBLIC_KEY_LENGTH]; + Qortal25519Extras.generatePublicKey(privateKey, 0, publicKey, 0); + + byte[] signature = signForAggregation(privateKey, timestampBytes); + + onlineAccounts.add(new OnlineAccountData(timestamp, signature, publicKey)); + } + + return onlineAccounts; + } +}