Skip to content

Commit

Permalink
Backport curves changes: drbg, truncateHash, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmillr committed Jan 8, 2023
1 parent bcfb393 commit b161b1c
Showing 1 changed file with 102 additions and 61 deletions.
163 changes: 102 additions & 61 deletions index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ const CURVE = Object.freeze({
Gx: BigInt('55066263022277343669578718895168534326250603453777594175500187360389116729240'),
Gy: BigInt('32670510020758816978083085130507043184471273380659243275938904335757337482424'),

// For endomorphism, see below
// Legacy, endo params are defined below
beta: BigInt('0x7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee'),
});

const divNearest = (a: bigint, b: bigint) => (a + b / _2n) / b;
// Endomorphism params
const endo = {
beta: BigInt('0x7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee'),
// Split 256-bit K into 2 128-bit (k1, k2) for which k1 + k2 * lambda = K.
Expand All @@ -60,8 +61,13 @@ const endo = {
return { k1neg, k1, k2neg, k2 };
},
};
const fieldLen = 32;
const groupLen = 32;

// Placeholder for non-sha256 hashes
const fieldLen = 32; // Field element: their range is 0 to CURVE.P
const groupLen = 32; // Group element: their range is 1 to CURVE.n
const hashLen = 32; // Hash used with secp256k1, sha2-256
const compressedLen = fieldLen + 1; // DER-encoded field element
const uncompressedLen = 2 * fieldLen + 1; // DER-encoded pair of field elements

// Cleaner js output if that's on a separate line.
export { CURVE };
Expand All @@ -81,7 +87,7 @@ function weierstrass(x: bigint): bigint {
type Hex = Uint8Array | string;
// Very few implementations accept numbers, we do it to ease learning curve
type PrivKey = Hex | bigint | number;
// 33/65-byte ECDSA key, or 32-byte Schnorr key - not interchangeable
// compressed/uncompressed ECDSA key, or Schnorr key - not interchangeable
type PubKey = Hex | Point;
// ECDSA signature
type Sig = Hex | Signature;
Expand All @@ -105,6 +111,10 @@ class ShaError extends Error {
}
}

function assertJacPoint(other: unknown) {
if (!(other instanceof JacobianPoint)) throw new TypeError('JacobianPoint expected');
}

/**
* Jacobian Point works in 3d / jacobi coordinates: (x, y, z) ∋ (x=x/z², y=y/z³)
* Default Point works in 2d / affine coordinates: (x, y)
Expand Down Expand Up @@ -142,7 +152,7 @@ class JacobianPoint {
* Compare one point to another.
*/
equals(other: JacobianPoint): boolean {
if (!(other instanceof JacobianPoint)) throw new TypeError('JacobianPoint expected');
assertJacPoint(other);
const { x: X1, y: Y1, z: Z1 } = this;
const { x: X2, y: Y2, z: Z2 } = other;
const Z1Z1 = mod(Z1 * Z1);
Expand Down Expand Up @@ -185,7 +195,7 @@ class JacobianPoint {
// Cost: 12M + 4S + 6add + 1*2
// Note: 2007 Bernstein-Lange (11M + 5S + 9add + 4*2) is actually 10% slower.
add(other: JacobianPoint): JacobianPoint {
if (!(other instanceof JacobianPoint)) throw new TypeError('JacobianPoint expected');
assertJacPoint(other);
const { x: X1, y: Y1, z: Z1 } = this;
const { x: X2, y: Y2, z: Z2 } = other;
if (X2 === _0n || Y2 === _0n) return this;
Expand Down Expand Up @@ -453,8 +463,8 @@ export class Point {
}

/**
* Supports compressed Schnorr (32-byte) and ECDSA (33-byte) points
* @param bytes 32/33 bytes
* Supports compressed Schnorr and ECDSA points
* @param bytes
* @returns Point instance
*/
private static fromCompressedHex(bytes: Uint8Array) {
Expand Down Expand Up @@ -488,19 +498,22 @@ export class Point {

/**
* Converts hash string or Uint8Array to Point.
* @param hex 32-byte (schnorr) or 33/65-byte (ECDSA) hex
* @param hex schnorr or ECDSA hex
*/
static fromHex(hex: Hex): Point {
const bytes = ensureBytes(hex);
const len = bytes.length;
const header = bytes[0];
// this.assertValidity() is done inside of those two functions
if (len === 32 || (len === 33 && (header === 0x02 || header === 0x03))) {
// Schnorr
if (len === fieldLen) return this.fromCompressedHex(bytes);
// ECDSA
if (len === compressedLen && (header === 0x02 || header === 0x03)) {
return this.fromCompressedHex(bytes);
}
if (len === 65 && header === 0x04) return this.fromUncompressedHex(bytes);
if (len === uncompressedLen && header === 0x04) return this.fromUncompressedHex(bytes);
throw new Error(
`Point.fromHex: received invalid point. Expected 32-33 compressed bytes or 65 uncompressed bytes, not ${len}`
`Point.fromHex: received invalid point. Expected 32-${compressedLen} compressed bytes or ${uncompressedLen} uncompressed bytes, not ${len}`
);
}

Expand Down Expand Up @@ -700,7 +713,7 @@ export class Signature {
}

normalizeS(): Signature {
return this.hasHighS() ? new Signature(this.r, CURVE.n - this.s) : this;
return this.hasHighS() ? new Signature(this.r, mod(-this.s, CURVE.n)) : this;
}

// DER-encoded
Expand All @@ -710,9 +723,11 @@ export class Signature {
toDERHex() {
const sHex = sliceDER(numberToHexUnpadded(this.s));
const rHex = sliceDER(numberToHexUnpadded(this.r));
const rLen = numberToHexUnpadded(rHex.length / 2);
const sLen = numberToHexUnpadded(sHex.length / 2);
const length = numberToHexUnpadded(rHex.length / 2 + sHex.length / 2 + 4);
const sHexL = sHex.length / 2;
const rHexL = rHex.length / 2;
const sLen = numberToHexUnpadded(sHexL);
const rLen = numberToHexUnpadded(rHexL);
const length = numberToHexUnpadded(rHexL + sHexL + 4);
return `30${length}02${rLen}${rHex}02${sLen}${sHex}`;
}

Expand Down Expand Up @@ -922,15 +937,19 @@ function invertBatch(nums: bigint[], p: bigint = CURVE.P): bigint[] {
return scratch;
}

// Can be replaced by bytesToNumber(). Placeholder for non-sha256 hashes
function bits2int_2(bytes: Uint8Array) {
const delta = bytes.length * 8 - groupLen * 8; // 256-256=0 for sha256/secp256k1
const num = bytesToNumber(bytes);
return delta > 0 ? num >> BigInt(delta) : num;
}

// Ensures ECDSA message hashes are 32 bytes and < curve order
function truncateHash(hash: Uint8Array): bigint {
function truncateHash(hash: Uint8Array, truncateOnly = false): bigint {
const h = bits2int_2(hash);
if (truncateOnly) return h;
const { n } = CURVE;
const byteLength = hash.length;
const delta = byteLength * 8 - 256; // size of curve.n
let h = bytesToNumber(hash);
if (delta > 0) h = h >> BigInt(delta);
if (h >= n) h -= n;
return h;
return h >= n ? h - n : h;
}

// RFC6979 related code
Expand All @@ -948,10 +967,13 @@ class HmacDrbg {
k: Uint8Array;
v: Uint8Array;
counter: number;
constructor() {
// Step B, Step C
this.v = new Uint8Array(32).fill(1);
this.k = new Uint8Array(32).fill(0);
constructor(public hashLen: number, public qByteLen: number) {
if (typeof hashLen !== 'number' || hashLen < 2) throw new Error('hashLen must be a number');
if (typeof qByteLen !== 'number' || qByteLen < 2) throw new Error('qByteLen must be a number');

// Step B, Step C: set hashLen to 8*ceil(hlen/8)
this.v = new Uint8Array(hashLen).fill(1);
this.k = new Uint8Array(hashLen).fill(0);
this.counter = 0;
}
private hmac(...values: Uint8Array[]) {
Expand Down Expand Up @@ -987,24 +1009,40 @@ class HmacDrbg {

async generate(): Promise<Uint8Array> {
this.incr();
this.v = await this.hmac(this.v);
return this.v;
let len = 0;
const out: Uint8Array[] = [];
while (len < this.qByteLen) {
this.v = await this.hmac(this.v);
const sl = this.v.slice();
out.push(sl);
len += this.v.length;
}
return concatBytes(...out);
}
generateSync(): Uint8Array {
this.checkSync();
this.incr();
this.v = this.hmacSync(this.v);
return this.v;
let len = 0;
const out: Uint8Array[] = [];
while (len < this.qByteLen) {
this.v = this.hmacSync(this.v);
const sl = this.v.slice();
out.push(sl);
len += this.v.length;
}
return concatBytes(...out);
}
// There is no need in clean() method
// It's useless, there are no guarantees with JS GC
// whether bigints are removed even if you clean Uint8Arrays.
}

// Valid scalars are [1, n-1]
function isWithinCurveOrder(num: bigint): boolean {
return _0n < num && num < CURVE.n;
}

// Valid field elements are [1, p-1]
function isValidFieldElement(num: bigint): boolean {
return _0n < num && num < CURVE.P;
}
Expand All @@ -1017,20 +1055,27 @@ function isValidFieldElement(num: bigint): boolean {
* @param d private key
* @returns Signature with its point on curve Q OR undefined if params were invalid
*/
function kmdToSig(kBytes: Uint8Array, m: bigint, d: bigint): RecoveredSig | undefined {
const k = bytesToNumber(kBytes);
function kmdToSig(kBytes: Uint8Array, m: bigint, d: bigint, lowS = true): RecoveredSig | undefined {
const { n } = CURVE;
const k = truncateHash(kBytes, true);
if (!isWithinCurveOrder(k)) return;
// Important: all mod() calls in the function must be done over `n`
const { n } = CURVE;
const kinv = invert(k, n);
const q = Point.BASE.multiply(k);
// r = x mod n
const r = mod(q.x, n);
if (r === _0n) return;
// s = (1/k * (m + dr) mod n
const s = mod(invert(k, n) * mod(m + d * r, n), n);
// s = (m + dr)/k mod n where x/k == x*inv(k)
const s = mod(kinv * mod(m + d * r, n), n);
if (s === _0n) return;
const sig = new Signature(r, s);
const recovery = (q.x === sig.r ? 0 : 2) | Number(q.y & _1n);

// Recovery bit adjustment
let sig = new Signature(r, s);
let recovery = (q.x === sig.r ? 0 : 2) | Number(q.y & _1n);
if (lowS && sig.hasHighS()) {
sig = sig.normalizeS();
recovery ^= 1;
}
return { sig, recovery };
}

Expand Down Expand Up @@ -1085,7 +1130,7 @@ function normalizeSignature(signature: Sig): Signature {
/**
* Computes public key for secp256k1 private key.
* @param privateKey 32-byte private key
* @param isCompressed whether to return compact (33-byte), or full (65-byte) key
* @param isCompressed whether to return compact, or full key
* @returns Public key, full by default; short when isCompressed=true
*/
export function getPublicKey(privateKey: PrivKey, isCompressed = false): Uint8Array {
Expand All @@ -1097,7 +1142,7 @@ export function getPublicKey(privateKey: PrivKey, isCompressed = false): Uint8Ar
* @param msgHash message hash
* @param signature DER or compact sig
* @param recovery 0 or 1
* @param isCompressed whether to return compact (33-byte), or full (65-byte) key
* @param isCompressed whether to return compact, or full key
* @returns Public key, full by default; short when isCompressed=true
*/
export function recoverPublicKey(
Expand All @@ -1116,8 +1161,8 @@ function isProbPub(item: PrivKey | PubKey): boolean {
const arr = item instanceof Uint8Array;
const str = typeof item === 'string';
const len = (arr || str) && (item as Hex).length;
if (arr) return len === 33 || len === 65;
if (str) return len === 66 || len === 130;
if (arr) return len === compressedLen || len === uncompressedLen;
if (str) return len === compressedLen * 2 || len === uncompressedLen * 2;
if (item instanceof Point) return true;
return false;
}
Expand All @@ -1128,7 +1173,7 @@ function isProbPub(item: PrivKey | PubKey): boolean {
* 2. Checks for the public key of being on-curve
* @param privateA private key
* @param publicB different public key
* @param isCompressed whether to return compact (33-byte), or full (65-byte) key
* @param isCompressed whether to return compact, or full key
* @returns shared public key
*/
export function getSharedSecret(
Expand All @@ -1152,7 +1197,7 @@ type SignOutput = Uint8Array | [Uint8Array, number];

// RFC6979 methods
function bits2int(bytes: Uint8Array) {
const slice = bytes.length > 32 ? bytes.slice(0, 32) : bytes;
const slice = bytes.length > fieldLen ? bytes.slice(0, fieldLen) : bytes;
return bytesToNumber(slice);
}
function bits2octets(bytes: Uint8Array): Uint8Array {
Expand All @@ -1175,9 +1220,9 @@ function initSigArgs(msgHash: Hex, privateKey: PrivKey, extraEntropy?: Entropy)
const seedArgs = [int2octets(d), bits2octets(h1)];
// RFC6979 3.6: additional k' could be provided
if (extraEntropy != null) {
if (extraEntropy === true) extraEntropy = utils.randomBytes(32);
if (extraEntropy === true) extraEntropy = utils.randomBytes(fieldLen);
const e = ensureBytes(extraEntropy);
if (e.length !== 32) throw new Error('sign: Expected 32 bytes of extra data');
if (e.length !== fieldLen) throw new Error(`sign: Expected ${fieldLen} bytes of extra data`);
seedArgs.push(e);
}
// seed is constructed from private key and message
Expand All @@ -1191,12 +1236,8 @@ function initSigArgs(msgHash: Hex, privateKey: PrivKey, extraEntropy?: Entropy)
// Takes signature with its recovery bit, normalizes it
// Produces DER/compact signature and proper recovery bit
function finalizeSig(recSig: RecoveredSig, opts: OptsNoRecov | OptsRecov): SignOutput {
let { sig, recovery } = recSig;
const { canonical, der, recovered } = Object.assign({ canonical: true, der: true }, opts);
if (canonical && sig.hasHighS()) {
sig = sig.normalizeS();
recovery ^= 1;
}
const { sig, recovery } = recSig;
const { der, recovered } = Object.assign({ canonical: true, der: true }, opts);
const hashed = der ? sig.toDERRawBytes() : sig.toCompactRawBytes();
return recovered ? [hashed, recovery] : hashed;
}
Expand Down Expand Up @@ -1228,10 +1269,10 @@ async function sign(msgHash: Hex, privKey: PrivKey, opts: Opts = {}): Promise<Si
const { seed, m, d } = initSigArgs(msgHash, privKey, opts.extraEntropy);
let sig: RecoveredSig | undefined;
// Steps B, C, D, E, F, G
const drbg = new HmacDrbg();
const drbg = new HmacDrbg(hashLen, groupLen);
await drbg.reseed(seed);
// Step H3, repeat until k is in range [1, n-1]
while (!(sig = kmdToSig(await drbg.generate(), m, d))) await drbg.reseed();
while (!(sig = kmdToSig(await drbg.generate(), m, d, opts.canonical))) await drbg.reseed();
return finalizeSig(sig, opts);
}

Expand All @@ -1247,10 +1288,10 @@ function signSync(msgHash: Hex, privKey: PrivKey, opts: Opts = {}): SignOutput {
const { seed, m, d } = initSigArgs(msgHash, privKey, opts.extraEntropy);
let sig: RecoveredSig | undefined;
// Steps B, C, D, E, F, G
const drbg = new HmacDrbg();
const drbg = new HmacDrbg(hashLen, groupLen);
drbg.reseedSync(seed);
// Step H3, repeat until k is in range [1, n-1]
while (!(sig = kmdToSig(drbg.generateSync(), m, d))) drbg.reseedSync();
while (!(sig = kmdToSig(drbg.generateSync(), m, d, opts.canonical))) drbg.reseedSync();
return finalizeSig(sig, opts);
}
export { sign, signSync };
Expand Down Expand Up @@ -1521,7 +1562,7 @@ export const utils = {
_normalizePrivateKey: normalizePrivateKey,

/**
* Can take 40 or more bytes of uniform input e.g. from CSPRNG or KDF
* Can take (n+8) or more bytes of uniform input e.g. from CSPRNG or KDF
* and convert them into private key, with the modulo bias being neglible.
* As per FIPS 186 B.4.1.
* https://research.kudelskisecurity.com/2020/07/28/the-definitive-guide-to-modulo-bias-and-how-to-avoid-it/
Expand All @@ -1530,9 +1571,9 @@ export const utils = {
*/
hashToPrivateKey: (hash: Hex): Uint8Array => {
hash = ensureBytes(hash);
const minLen = fieldLen + 8;
const minLen = groupLen + 8;
if (hash.length < minLen || hash.length > 1024) {
throw new Error(`Expected ${minLen}-1024 bytes of private key as per FIPS 186`);
throw new Error(`Expected valid bytes of private key as per FIPS 186`);
}
const num = mod(bytesToNumber(hash), CURVE.n - _1n) + _1n;
return numTo32b(num);
Expand All @@ -1550,8 +1591,8 @@ export const utils = {
},

// Takes curve order + 64 bits from CSPRNG
// so that modulo bias is neglible, matches FIPS 186 B.1.1.
randomPrivateKey: (): Uint8Array => utils.hashToPrivateKey(utils.randomBytes(fieldLen + 8)),
// so that modulo bias is neglible, matches FIPS 186 B.4.1.
randomPrivateKey: (): Uint8Array => utils.hashToPrivateKey(utils.randomBytes(groupLen + 8)),

/**
* 1. Returns cached point which you can use to pass to `getSharedSecret` or `#multiply` by it.
Expand Down

0 comments on commit b161b1c

Please sign in to comment.