From 710813d8240cf4705f0412afb070f228569643d6 Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Tue, 24 Oct 2023 15:11:49 -0300 Subject: [PATCH] Implement strauss shamir trick (#171) * wip * adds shamir trick for the linear combination * remove unused functions * remove unused functions * remove alternative add implementation --- precompiles/P256VERIFY.yul | 182 ++++++++++++++++++++----------------- 1 file changed, 99 insertions(+), 83 deletions(-) diff --git a/precompiles/P256VERIFY.yul b/precompiles/P256VERIFY.yul index 07a51f36..57caccf6 100644 --- a/precompiles/P256VERIFY.yul +++ b/precompiles/P256VERIFY.yul @@ -115,13 +115,6 @@ object "P256VERIFY" { precompileCall(0, gas()) } - /// @notice Checks if the LSB of a number is 1. - /// @param x The number to check. - /// @return ret True if the LSB is 1, false otherwise. - function lsbIsOne(x) -> ret { - ret := and(x, 1) - } - // MONTGOMERY /// @notice Computes the inverse in Montgomery Form of a number in Montgomery Form. @@ -392,22 +385,6 @@ object "P256VERIFY" { zr := MONTGOMERY_ONE_P() } - /// @notice Converts a point in projective coordinates to affine coordinates in Montgomery form for modulus P(). - /// @dev See https://www.nayuki.io/page/elliptic-curve-point-addition-in-projective-coordinates for further details. - /// @dev Reverts if the point is not on the curve. - /// @param xp The x coordinate of the point P in projective coordinates in Montgomery form. - /// @param yp The y coordinate of the point P in projective coordinates in Montgomery form. - /// @param zp The z coordinate of the point P in projective coordinates in Montgomery form. - /// @return xr The x coordinate of the point P in affine coordinates in Montgomery form. - /// @return yr The y coordinate of the point P in affine coordinates in Montgomery form. - function projectiveIntoAffine(xp, yp, zp) -> xr, yr { - if zp { - let zp_inv := montgomeryModularInverse(zp, P(), R2_MOD_P()) - xr := montgomeryMul(xp, zp_inv, P(), P_PRIME()) - yr := montgomeryMul(yp, zp_inv, P(), P_PRIME()) - } - } - /// @notice Checks if a point in projective coordinates is the point at infinity. /// @dev The point at infinity is defined as the point (0, 0, 0). /// @param xp The x coordinate of the point P in projective coordinates in Montgomery form. @@ -460,7 +437,6 @@ object "P256VERIFY" { /// @return yr The y coordinate of the point P + Q in projective coordinates in Montgomery form. /// @return zr The z coordinate of the point P + Q in projective coordinates in Montgomery form. function projectiveAdd(xp, yp, zp, xq, yq, zq) -> xr, yr, zr { - let flag := 1 let qIsInfinity := projectivePointIsInfinity(xq, yq, zq) let pIsInfinity := projectivePointIsInfinity(xp, yp, zp) if and(pIsInfinity, qIsInfinity) { @@ -468,84 +444,129 @@ object "P256VERIFY" { xr := 0 yr := MONTGOMERY_ONE_P() zr := 0 - flag := 0 + leave } - if and(flag, pIsInfinity) { - // Infinity + P = P + if pIsInfinity { + // Infinity + Q = Q xr := xq yr := yq zr := zq - flag := 0 + leave } - if and(flag, qIsInfinity) { + if qIsInfinity { // P + Infinity = P xr := xp yr := yp zr := zp - flag := 0 + leave } - if and(flag, and(and(eq(xp, xq), eq(montgomerySub(0, yp, P()), yq)), eq(zp, zq))) { + // FIX ME: we need to check xp/zp == xq/zq and yp/zp == -yq/zq + if and(and(eq(xp, xq), eq(montgomerySub(0, yp, P()), yq)), eq(zp, zq)) { // P + (-P) = Infinity xr := 0 yr := MONTGOMERY_ONE_P() zr := 0 - flag := 0 + leave } - if and(flag, and(and(eq(xp, xq), eq(yp, yq)), eq(zp, zq))) { + // FIX ME: we need to check xp/zp == xq/zq and yp/zp == yq/zq + if and(and(eq(xp, xq), eq(yp, yq)), eq(zp, zq)) { // P + P = 2P xr, yr, zr := projectiveDouble(xp, yp, zp) - flag := 0 + leave } // P1 + P2 = P3 - if flag { - let t0 := montgomeryMul(yq, zp, P(), P_PRIME()) - let t1 := montgomeryMul(yp, zq, P(), P_PRIME()) - let t := montgomerySub(t0, t1, P()) - let u0 := montgomeryMul(xq, zp, P(), P_PRIME()) - let u1 := montgomeryMul(xp, zq, P(), P_PRIME()) - let u := montgomerySub(u0, u1, P()) - let u2 := montgomeryMul(u, u, P(), P_PRIME()) - let u3 := montgomeryMul(u2, u, P(), P_PRIME()) - let v := montgomeryMul(zq, zp, P(), P_PRIME()) - let w := montgomerySub(montgomeryMul(montgomeryMul(t, t, P(), P_PRIME()), v, P(), P_PRIME()), montgomeryMul(u2, montgomeryAdd(u0, u1, P()), P(), P_PRIME()), P()) - - xr := montgomeryMul(u, w, P(), P_PRIME()) - yr := montgomerySub(montgomeryMul(t, montgomerySub(montgomeryMul(u0, u2, P(), P_PRIME()), w, P()), P(), P_PRIME()), montgomeryMul(t0, u3, P(), P_PRIME()), P()) - zr := montgomeryMul(u3, v, P(), P_PRIME()) - } + let t0 := montgomeryMul(yp, zq, P(), P_PRIME()) + let t1 := montgomeryMul(yq, zp, P(), P_PRIME()) + let t := montgomerySub(t0, t1, P()) + let u0 := montgomeryMul(xp, zq, P(), P_PRIME()) + let u1 := montgomeryMul(xq, zp, P(), P_PRIME()) + let u := montgomerySub(u0, u1, P()) + let u2 := montgomeryMul(u, u, P(), P_PRIME()) + let u3 := montgomeryMul(u2, u, P(), P_PRIME()) + let v := montgomeryMul(zp, zq, P(), P_PRIME()) + let w := montgomerySub(montgomeryMul(montgomeryMul(t, t, P(), P_PRIME()), v, P(), P_PRIME()), montgomeryMul(u2, montgomeryAdd(u0, u1, P()), P(), P_PRIME()), P()) + + xr := montgomeryMul(u, w, P(), P_PRIME()) + yr := montgomerySub(montgomeryMul(t, montgomerySub(montgomeryMul(u0, u2, P(), P_PRIME()), w, P()), P(), P_PRIME()), montgomeryMul(t0, u3, P(), P_PRIME()), P()) + zr := montgomeryMul(u3, v, P(), P_PRIME()) } - /// @notice Computes the scalar multiplication of a point in projective coordinates in Montgomery form for modulus P(). - /// @param xp The x coordinate of the point P in projective coordinates in Montgomery form. - /// @param yp The y coordinate of the point P in projective coordinates in Montgomery form. - /// @param zp The z coordinate of the point P in projective coordinates in Montgomery form. - /// @param scalar The scalar to multiply the point by. - /// @return xr The x coordinate of the point scalar*P in projective coordinates in Montgomery form. - /// @return yr The y coordinate of the point scalar*P in projective coordinates in Montgomery form. - /// @return zr The z coordinate of the point scalar*P in projective coordinates in Montgomery form. - function projectiveScalarMul(xp, yp, zp, scalar) -> xr, yr, zr { - switch eq(scalar, 2) - case 0 { - let xq := xp - let yq := yp - let zq := zp - xr := 0 - yr := MONTGOMERY_ONE_P() - zr := 0 - for {} scalar {} { - if lsbIsOne(scalar) { - xr, yr, zr := projectiveAdd(xr, yr, zr, xq, yq, zq) - } - - xq, yq, zq := projectiveDouble(xq, yq, zq) - // Check next bit - scalar := shr(1, scalar) + /// @notice Computes the linear combination of curve points: t0*Q + t1*G = R + /// @param xq The x coordinate of the point Q in projective coordinates in Montgomery form. + /// @param yq The y coordinate of the point Q in projective coordinates in Montgomery form. + /// @param zq The z coordinate of the point Q in projective coordinates in Montgomery form. + /// @param t0 The scalar to multiply the point Q by. + /// @param t1 The scalar to multiply the generator G by. + /// @return xr The x coordinate of the resulting point R in projective coordinates in Montgomery form. + /// @return yr The y coordinate of the resulting point R in projective coordinates in Montgomery form. + /// @return zr The z coordinate of the resulting point R in projective coordinates in Montgomery form. + function shamirLinearCombination(xq, yq, zq, t0, t1) -> xr, yr, zr { + let xg, yg, zg := MONTGOMERY_PROJECTIVE_G_P() + let xh, yh, zh := projectiveAdd(xg, yg, zg, xq, yq, zq) + let index, ret := findMostSignificantBitIndex(t0, t1) + switch ret + case 1 { + xr := xq + yr := yq + zr := zq + } + case 2 { + xr := xg + yr := yg + zr := zg + } + case 3 { + xr := xh + yr := yh + zr := zh + } + let ret + for {} gt(index, 0) {} { + index := sub(index, 1) + xr, yr, zr := projectiveDouble(xr, yr, zr) + ret := compareBits(index, t0, t1) + switch ret + case 1 { + xr, yr, zr := projectiveAdd(xr, yr, zr, xq, yq, zq) + } + case 2 { + xr, yr, zr := projectiveAdd(xr, yr, zr, xg, yg, zg) + } + case 3 { + xr, yr, zr := projectiveAdd(xr, yr, zr, xh, yh, zh) } } - case 1 { - xr, yr, zr := projectiveDouble(xp, yp, zp) + } + + /// @notice Computes the largest index of the most significant bit among two scalars, + /// @notice and indicates which scalar it belongs to. + /// @param t0 The first scalar. + /// @param t1 The second scalar. + /// @return index The position of the most significant bit among t0 and t1. + /// @return ret Indicates which scalar the most significant bit belongs to. + /// @return return 1 if it belongs to t1, returns 2 if it belongs to t0, + /// @return and returns 3 if both have the most significant bit in the same position. + function findMostSignificantBitIndex(t0, t1) -> index, ret { + index := 255 + ret := 0 + for {} eq(ret, 0) { index := sub(index, 1) } { + ret := compareBits(index, t0, t1) } + index := add(index, 1) + } + + /// @notice Compares the bits between two scalars at a specific position. + /// @param index The position to compare. + /// @param t0 The first scalar. + /// @param t1 The second scalar. + /// @return ret A value that indicates the value of the bits. + /// @return ret is 0 if both are 0. + /// @return ret is 1 if the bit of t0 is 0 and the bit of t1 is 1. + /// @return ret is 2 if the bit of t0 is 1 and the bit of t1 is 0. + /// @return ret is 3 if both bits are 1. + function compareBits(index, t0, t1) -> ret { + ret := add(mul(and(shr(index, t0), 1), 2), and(shr(index, t1), 1)) } // Fallback @@ -584,12 +605,7 @@ object "P256VERIFY" { let t0 := outOfMontgomeryForm(montgomeryMul(hash, s1, N(), N_PRIME()), N(), N_PRIME()) let t1 := outOfMontgomeryForm(montgomeryMul(r, s1, N(), N_PRIME()), N(), N_PRIME()) - let gx, gy, gz := MONTGOMERY_PROJECTIVE_G_P() - - // TODO: Implement Shamir's trick for adding to scalar multiplications faster. - let xp, yp, zp := projectiveScalarMul(gx, gy, gz, t0) - let xq, yq, zq := projectiveScalarMul(x, y, z, t1) - let xr, yr, zr := projectiveAdd(xp, yp, zp, xq, yq, zq) + let xr, yr, zr := shamirLinearCombination(x, y, z, t0, t1) // As we only need xr in affine form, we can skip transforming the `y` coordinate. let z_inv := montgomeryModularInverse(zr, P(), R2_MOD_P())