From 9cd1a82a61c8e7555aa2441cfaee29a2da15959e Mon Sep 17 00:00:00 2001 From: therealyingtong Date: Sat, 21 Jan 2023 19:03:53 -0500 Subject: [PATCH] Add transcript --- curve.py | 46 +++++ prover.py | 482 +++++++++++++++++++++++++++++++++----------------- setup.py | 20 +-- test.py | 8 +- transcript.py | 28 +++ verifier.py | 180 ++++++++++--------- 6 files changed, 501 insertions(+), 263 deletions(-) create mode 100644 curve.py create mode 100644 transcript.py diff --git a/curve.py b/curve.py new file mode 100644 index 0000000..9f3ed92 --- /dev/null +++ b/curve.py @@ -0,0 +1,46 @@ +from py_ecc.fields.field_elements import FQ as Field +import py_ecc.bn128 as b +from typing import NewType +from functools import cache +from utils import lincomb +from dataclasses import dataclass + +primitive_root = 5 + +class Scalar(Field): + field_modulus = b.curve_order + + # Gets the first root of unity of a given group order + @classmethod + def root_of_unity(cls, group_order:int): + return Scalar(5) ** ((cls.field_modulus - 1) // group_order) + + # Gets the full list of roots of unity of a given group order + @classmethod + def roots_of_unity(cls, group_order: int): + o = [Scalar(1), cls.root_of_unity(group_order)] + while len(o) < group_order: + o.append(o[-1] * o[1]) + return o + +Base = NewType('Base', b.FQ) + +def ec_mul(pt, coeff): + if hasattr(coeff, 'n'): + coeff = coeff.n + return b.multiply(pt, coeff % b.curve_order) + +# Elliptic curve linear combination. A truly optimized implementation +# would replace this with a fast lin-comb algo, see https://ethresear.ch/t/7238 +def ec_lincomb(pairs): + return lincomb( + [pt for (pt, _) in pairs], + [int(n) % b.curve_order for (_, n) in pairs], + b.add, + b.Z1 + ) + # Equivalent to: + # o = b.Z1 + # for pt, coeff in pairs: + # o = b.add(o, ec_mul(pt, coeff)) + # return o diff --git a/prover.py b/prover.py index 3e4389b..877f87e 100644 --- a/prover.py +++ b/prover.py @@ -4,151 +4,233 @@ from setup import * from typing import Optional from dataclasses import dataclass +from transcript import Transcript +from curve import Scalar @dataclass -class Proof: - # [a(x)]₁ (commitment to left wire polynomial) - a_1: G1Point - # [b(x)]₁ (commitment to right wire polynomial) - b_1: G1Point - # [c(x)]₁ (commitment to output wire polynomial) - c_1: G1Point - # [z(x)]₁ (commitment to permutation polynomial) - z_1: G1Point - # [t_lo(x)]₁ (commitment to t_lo(X), the low chunk of the quotient polynomial t(X)) - t_lo_1: G1Point - # [t_mid(x)]₁ (commitment to t_mid(X), the middle chunk of the quotient polynomial t(X)) - t_mid_1: G1Point - # [t_hi(x)]₁ (commitment to t_hi(X), the high chunk of the quotient polynomial t(X)) - t_hi_1: G1Point - # Evaluation of a(X) at evaluation challenge ζ - a_eval: f_inner - # Evaluation of b(X) at evaluation challenge ζ - b_eval: f_inner - # Evaluation of c(X) at evaluation challenge ζ - c_eval: f_inner - # Evaluation of the first permutation polynomial S_σ1(X) at evaluation challenge ζ - s1_eval: f_inner - # Evaluation of the second permutation polynomial S_σ2(X) at evaluation challenge ζ - s2_eval: f_inner - # Evaluation of the shifted permutation polynomial z(X) at the shifted evaluation challenge ζω - z_shifted_eval: f_inner - # [W_ζ(X)]₁ (commitment to the opening proof polynomial) - W_z_1: G1Point - # [W_ζω(X)]₁ (commitment to the opening proof polynomial) - W_zw_1: G1Point - - @classmethod - def prove_from_witness(cls, setup: Setup, program: Program, witness: dict[Optional[str], int]): - group_order = program.group_order +class Prover: + def prove( + self, + setup: Setup, + program: Program, + witness: dict[Optional[str], int] + ): + self.group_order = program.group_order + proof = {} + + # Initialise Fiat-Shamir transcript + transcript = Transcript() + + # Collect fixed and public information + self.init(program, witness) + + # Round 1 + # - [a(x)]₁ (commitment to left wire polynomial) + # - [b(x)]₁ (commitment to right wire polynomial) + # - [c(x)]₁ (commitment to output wire polynomial) + a_1, b_1, c_1 = self.round_1(program, witness, transcript, setup) + proof['a_1'] = a_1 + proof['b_1'] = b_1 + proof['c_1'] = c_1 + + # Round 2 + # - [z(x)]₁ (commitment to permutation polynomial) + z_1 = self.round_2(transcript, setup) + proof['z_1'] = z_1 + + # Round 3 + # - [t_lo(x)]₁ (commitment to t_lo(X), the low chunk of the quotient polynomial t(X)) + # - [t_mid(x)]₁ (commitment to t_mid(X), the middle chunk of the quotient polynomial t(X)) + # - [t_hi(x)]₁ (commitment to t_hi(X), the high chunk of the quotient polynomial t(X)) + t_lo_1, t_mid_1, t_hi_1 = self.round_3(transcript, setup) + proof['t_lo_1'] = t_lo_1 + proof['t_mid_1'] = t_mid_1 + proof['t_hi_1'] = t_hi_1 + + # Round 4 + # - Evaluation of a(X) at evaluation challenge ζ + # - Evaluation of b(X) at evaluation challenge ζ + # - Evaluation of c(X) at evaluation challenge ζ + # - Evaluation of the first permutation polynomial S_σ1(X) at evaluation challenge ζ + # - Evaluation of the second permutation polynomial S_σ2(X) at evaluation challenge ζ + # - Evaluation of the shifted permutation polynomial z(X) at the shifted evaluation challenge ζω + a_eval, b_eval, c_eval, s1_eval, s2_eval, z_shifted_eval = self.round_4(transcript) + proof['a_eval'] = a_eval + proof['b_eval'] = b_eval + proof['c_eval'] = c_eval + proof['s1_eval'] = s1_eval + proof['s2_eval'] = s2_eval + proof['z_shifted_eval'] = z_shifted_eval + + # Round 5 + # - [W_ζ(X)]₁ (commitment to the opening proof polynomial) + # - [W_ζω(X)]₁ (commitment to the opening proof polynomial) + W_z_1, W_zw_1 = self.round_5(transcript, setup) + proof['W_z_1'] = W_z_1 + proof['W_zw_1'] = W_zw_1 + + return proof + + def init( + self, + program: Program, + witness: dict[Optional[str], int] + ): + group_order = self.group_order + + QL, QR, QM, QO, QC = program.make_gate_polynomials() + # Compute the accumulator polynomial for the permutation arguments + S = program.make_s_polynomials() + S1 = S[Column.LEFT] + S2 = S[Column.RIGHT] + S3 = S[Column.OUTPUT] + + public_vars = program.get_public_assignments() + PI = ( + [Scalar(-witness[v]) for v in public_vars] + + [Scalar(0) for _ in range(group_order - len(public_vars))] + ) + + self.QL = QL + self.QR = QR + self.QM = QM + self.QO = QO + self.QC = QC + self.S1 = S1 + self.S2 = S2 + self.S3 = S3 + self.PI = PI + + def round_1( + self, + program: Program, + witness: dict[Optional[str], int], + transcript: Transcript, setup: Setup + ): + group_order = self.group_order if None not in witness: witness[None] = 0 # Compute wire assignments - A = [f_inner(0) for _ in range(group_order)] - B = [f_inner(0) for _ in range(group_order)] - C = [f_inner(0) for _ in range(group_order)] + A = [Scalar(0) for _ in range(group_order)] + B = [Scalar(0) for _ in range(group_order)] + C = [Scalar(0) for _ in range(group_order)] for i, gate_wires in enumerate(program.wires()): - A[i] = f_inner(witness[gate_wires.L]) - B[i] = f_inner(witness[gate_wires.R]) - C[i] = f_inner(witness[gate_wires.O]) - a_1 = setup.evaluations_to_point(A) - b_1 = setup.evaluations_to_point(B) - c_1 = setup.evaluations_to_point(C) + A[i] = Scalar(witness[gate_wires.L]) + B[i] = Scalar(witness[gate_wires.R]) + C[i] = Scalar(witness[gate_wires.O]) - public_vars = program.get_public_assignments() - PI = ( - [f_inner(-witness[v]) for v in public_vars] + - [f_inner(0) for _ in range(group_order - len(public_vars))] - ) + a_1 = setup.commit(A) + transcript.hash_point(a_1) + + b_1 = setup.commit(B) + transcript.hash_point(b_1) + + c_1 = setup.commit(C) + transcript.hash_point(c_1) + + self.A = A + self.B = B + self.C = C + + # Sanity check that witness fulfils gate constraints + for i in range(group_order): + assert ( + A[i] * self.QL[i] + B[i] * self.QR[i] + A[i] * B[i] * self.QM[i] + + C[i] * self.QO[i] + self.PI[i] + self.QC[i] == 0 + ) - buf = serialize_point(a_1) + serialize_point(b_1) + serialize_point(c_1) + return a_1, b_1, c_1 + + def round_2( + self, + transcript: Transcript, + setup: Setup, + ): + group_order = self.group_order # The first two Fiat-Shamir challenges - beta = binhash_to_f_inner(keccak256(buf)) - gamma = binhash_to_f_inner(keccak256(keccak256(buf))) + beta = transcript.squeeze() + transcript.beta = beta + self.beta = beta - # Compute the accumulator polynomial for the permutation arguments - S = program.make_s_polynomials() - S1 = S[Column.LEFT] - S2 = S[Column.RIGHT] - S3 = S[Column.OUTPUT] - Z = [f_inner(1)] + gamma = transcript.squeeze() + transcript.gamma = gamma + self.gamma = gamma + + Z = [Scalar(1)] roots_of_unity = get_roots_of_unity(group_order) for i in range(group_order): Z.append( Z[-1] * - (A[i] + beta * roots_of_unity[i] + gamma) * - (B[i] + beta * 2 * roots_of_unity[i] + gamma) * - (C[i] + beta * 3 * roots_of_unity[i] + gamma) / - (A[i] + beta * S1[i] + gamma) / - (B[i] + beta * S2[i] + gamma) / - (C[i] + beta * S3[i] + gamma) + (self.A[i] + beta * roots_of_unity[i] + gamma) * + (self.B[i] + beta * 2 * roots_of_unity[i] + gamma) * + (self.C[i] + beta * 3 * roots_of_unity[i] + gamma) / + (self.A[i] + beta * self.S1[i] + gamma) / + (self.B[i] + beta * self.S2[i] + gamma) / + (self.C[i] + beta * self.S3[i] + gamma) ) assert Z.pop() == 1 - z_1 = setup.evaluations_to_point(Z) - alpha = binhash_to_f_inner(keccak256(serialize_point(z_1))) + + # Sanity-check that Z was computed correctly + for i in range(group_order): + assert ( + self.rlc(self.A[i], roots_of_unity[i]) * + self.rlc(self.B[i], 2 * roots_of_unity[i]) * + self.rlc(self.C[i], 3 * roots_of_unity[i]) + ) * Z[i] - ( + self.rlc(self.A[i], self.S1[i]) * + self.rlc(self.B[i], self.S2[i]) * + self.rlc(self.C[i], self.S3[i]) + ) * Z[(i+1) % group_order] == 0 + + z_1 = setup.commit(Z) + transcript.hash_point(z_1) print("Permutation accumulator polynomial successfully generated") + self.Z = Z + return z_1 + + def round_3(self, transcript: Transcript, setup: Setup): + group_order = self.group_order + # Compute the quotient polynomial # List of roots of unity at 4x fineness quarter_roots = get_roots_of_unity(group_order * 4) + alpha = transcript.squeeze() + transcript.alpha = alpha + # This value could be anything, it just needs to be unpredictable. Lets us # have evaluation forms at cosets to avoid zero evaluations, so we can # divide polys without the 0/0 issue - fft_offset = binhash_to_f_inner( - keccak256(keccak256(serialize_point(z_1))) - ) - - fft_expand = lambda x: fft_expand_with_offset(x, fft_offset) - expanded_evals_to_coeffs = lambda x: offset_evals_to_coeffs(x, fft_offset) + fft_cofactor = transcript.squeeze() + transcript.fft_cofactor = fft_cofactor + self.fft_cofactor = fft_cofactor - A_big = fft_expand(A) - B_big = fft_expand(B) - C_big = fft_expand(C) + A_big = self.fft_expand(self.A) + B_big = self.fft_expand(self.B) + C_big = self.fft_expand(self.C) # Z_H = X^N - 1, also in evaluation form in the coset ZH_big = [ - ((f_inner(r) * fft_offset) ** group_order - 1) + ((Scalar(r) * fft_cofactor) ** group_order - 1) for r in quarter_roots ] - QL, QR, QM, QO, QC = program.make_gate_polynomials() - QL_big, QR_big, QM_big, QO_big, QC_big, PI_big = \ - (fft_expand(x) for x in (QL, QR, QM, QO, QC, PI)) + (self.fft_expand(x) for x in (self.QL, self.QR, self.QM, self.QO, self.QC, self.PI)) - Z_big = fft_expand(Z) + Z_big = self.fft_expand(self.Z) Z_shifted_big = Z_big[4:] + Z_big[:4] - S1_big = fft_expand(S1) - S2_big = fft_expand(S2) - S3_big = fft_expand(S3) + S1_big = self.fft_expand(self.S1) + S2_big = self.fft_expand(self.S2) + S3_big = self.fft_expand(self.S3) # Equals 1 at x=1 and 0 at other roots of unity - L1_big = fft_expand([f_inner(1)] + [f_inner(0)] * (group_order - 1)) - - # Some sanity checks to make sure everything is ok up to here - for i in range(group_order): - # print('a', A[i], 'b', B[i], 'c', C[i]) - # print('ql', QL[i], 'qr', QR[i], 'qm', QM[i], 'qo', QO[i], 'qc', QC[i]) - assert ( - A[i] * QL[i] + B[i] * QR[i] + A[i] * B[i] * QM[i] + - C[i] * QO[i] + PI[i] + QC[i] == 0 - ) - - for i in range(group_order): - assert ( - (A[i] + beta * roots_of_unity[i] + gamma) * - (B[i] + beta * 2 * roots_of_unity[i] + gamma) * - (C[i] + beta * 3 * roots_of_unity[i] + gamma) - ) * Z[i] - ( - (A[i] + beta * S1[i] + gamma) * - (B[i] + beta * S2[i] + gamma) * - (C[i] + beta * S3[i] + gamma) - ) * Z[(i+1) % group_order] == 0 + L1_big = self.fft_expand([Scalar(1)] + [Scalar(0)] * (group_order - 1)) # Compute the quotient polynomial (called T(x) in the paper) # It is only possible to construct this polynomial if the following @@ -165,6 +247,8 @@ def prove_from_witness(cls, setup: Setup, program: Program, witness: dict[Option # (Z - 1) * L1 = 0 # L1 = Lagrange polynomial, equal at all roots of unity except 1 + beta = transcript.beta + gamma = transcript.gamma QUOT_big = [(( A_big[i] * QL_big[i] + B_big[i] * QR_big[i] + @@ -173,9 +257,9 @@ def prove_from_witness(cls, setup: Setup, program: Program, witness: dict[Option PI_big[i] + QC_big[i] ) + ( - (A_big[i] + beta * fft_offset * quarter_roots[i] + gamma) * - (B_big[i] + beta * 2 * fft_offset * quarter_roots[i] + gamma) * - (C_big[i] + beta * 3 * fft_offset * quarter_roots[i] + gamma) + (A_big[i] + beta * fft_cofactor * quarter_roots[i] + gamma) * + (B_big[i] + beta * 2 * fft_cofactor * quarter_roots[i] + gamma) * + (C_big[i] + beta * 3 * fft_cofactor * quarter_roots[i] + gamma) ) * alpha * Z_big[i] - ( (A_big[i] + beta * S1_big[i] + gamma) * (B_big[i] + beta * S2_big[i] + gamma) * @@ -184,11 +268,11 @@ def prove_from_witness(cls, setup: Setup, program: Program, witness: dict[Option (Z_big[i] - 1) * L1_big[i] * alpha**2 )) / ZH_big[i] for i in range(group_order * 4)] - all_coeffs = expanded_evals_to_coeffs(QUOT_big) + all_coeffs = self.expanded_evals_to_coeffs(QUOT_big) # Sanity check: QUOT has degree < 3n assert ( - expanded_evals_to_coeffs(QUOT_big)[-group_order:] == + self.expanded_evals_to_coeffs(QUOT_big)[-group_order:] == [0] * group_order ) print("Generated the quotient polynomial") @@ -199,21 +283,37 @@ def prove_from_witness(cls, setup: Setup, program: Program, witness: dict[Option T2 = f_inner_fft(all_coeffs[group_order: group_order * 2]) T3 = f_inner_fft(all_coeffs[group_order * 2: group_order * 3]) - t_lo_1 = setup.evaluations_to_point(T1) - t_mid_1 = setup.evaluations_to_point(T2) - t_hi_1 = setup.evaluations_to_point(T3) - print("Generated T1, T2, T3 polynomials") - - buf2 = serialize_point(t_lo_1)+serialize_point(t_mid_1)+serialize_point(t_hi_1) - zed = binhash_to_f_inner(keccak256(buf2)) - # Sanity check that we've computed T1, T2, T3 correctly assert ( - barycentric_eval_at_point(T1, fft_offset) + - barycentric_eval_at_point(T2, fft_offset) * fft_offset**group_order + - barycentric_eval_at_point(T3, fft_offset) * fft_offset**(group_order*2) + barycentric_eval_at_point(T1, fft_cofactor) + + barycentric_eval_at_point(T2, fft_cofactor) * fft_cofactor**group_order + + barycentric_eval_at_point(T3, fft_cofactor) * fft_cofactor**(group_order*2) ) == QUOT_big[0] + print("Generated T1, T2, T3 polynomials") + + t_lo_1 = setup.commit(T1) + transcript.hash_point(t_lo_1) + + t_mid_1 = setup.commit(T2) + transcript.hash_point(t_mid_1) + + t_hi_1 = setup.commit(T3) + transcript.hash_point(t_hi_1) + + self.T1 = T1 + self.T2 = T2 + self.T3 = T3 + + return t_lo_1, t_mid_1, t_hi_1 + + def round_4(self, transcript: Transcript): + group_order = self.group_order + + zed = transcript.squeeze() + transcript.zed = zed + self.zed = zed + # Compute the "linearization polynomial" R. This is a clever way to avoid # needing to provide evaluations of _all_ the polynomials that we are # checking an equation betweeen: instead, we can "skip" the first @@ -227,37 +327,73 @@ def prove_from_witness(cls, setup: Setup, program: Program, witness: dict[Option # it has to be "linear" in the proof items, hence why we can only use each # proof item once; any further multiplicands in each term need to be # replaced with their evaluations at Z, which do still need to be provided - a_eval = barycentric_eval_at_point(A, zed) - b_eval = barycentric_eval_at_point(B, zed) - c_eval = barycentric_eval_at_point(C, zed) - s1_eval = barycentric_eval_at_point(S1, zed) - s2_eval = barycentric_eval_at_point(S2, zed) - z_shifted_eval = barycentric_eval_at_point(Z, zed * roots_of_unity[1]) + a_eval = barycentric_eval_at_point(self.A, zed) + transcript.hash_scalar(a_eval) + + b_eval = barycentric_eval_at_point(self.B, zed) + transcript.hash_scalar(b_eval) + c_eval = barycentric_eval_at_point(self.C, zed) + transcript.hash_scalar(c_eval) + + s1_eval = barycentric_eval_at_point(self.S1, zed) + transcript.hash_scalar(s1_eval) + + s2_eval = barycentric_eval_at_point(self.S2, zed) + transcript.hash_scalar(s2_eval) + + root_of_unity = get_root_of_unity(group_order) + z_shifted_eval = barycentric_eval_at_point(self.Z, zed * root_of_unity) + transcript.hash_scalar(z_shifted_eval) + + self.a_eval = a_eval + self.b_eval = b_eval + self.c_eval = c_eval + self.s1_eval = s1_eval + self.s2_eval = s2_eval + self.z_shifted_eval = z_shifted_eval + + return a_eval, b_eval, c_eval, s1_eval, s2_eval, z_shifted_eval + + def round_5(self, transcript: Transcript, setup: Setup): + group_order = self.group_order + + v = transcript.squeeze() + transcript.v = v + + zed = transcript.zed L1_ev = barycentric_eval_at_point([1] + [0] * (group_order - 1), zed) ZH_ev = zed ** group_order - 1 - PI_ev = barycentric_eval_at_point(PI, zed) + PI_ev = barycentric_eval_at_point(self.PI, zed) - T1_big = fft_expand(T1) - T2_big = fft_expand(T2) - T3_big = fft_expand(T3) + T1_big = self.fft_expand(self.T1) + T2_big = self.fft_expand(self.T2) + T3_big = self.fft_expand(self.T3) + QL_big, QR_big, QM_big, QO_big, QC_big, PI_big = \ + (self.fft_expand(x) for x in (self.QL, self.QR, self.QM, self.QO, self.QC, self.PI)) + Z_big = self.fft_expand(self.Z) + S3_big = self.fft_expand(self.S3) + + beta = transcript.beta + gamma = transcript.gamma + alpha = transcript.alpha R_big = [( - a_eval * QL_big[i] + - b_eval * QR_big[i] + - a_eval * b_eval * QM_big[i] + - c_eval * QO_big[i] + + self.a_eval * QL_big[i] + + self.b_eval * QR_big[i] + + self.a_eval * self.b_eval * QM_big[i] + + self.c_eval * QO_big[i] + PI_ev + QC_big[i] ) + ( - (a_eval + beta * zed + gamma) * - (b_eval + beta * 2 * zed + gamma) * - (c_eval + beta * 3 * zed + gamma) + (self.a_eval + beta * zed + gamma) * + (self.b_eval + beta * 2 * zed + gamma) * + (self.c_eval + beta * 3 * zed + gamma) ) * alpha * Z_big[i] - ( - (a_eval + beta * s1_eval + gamma) * - (b_eval + beta * s2_eval + gamma) * - (c_eval + beta * S3_big[i] + gamma) - ) * alpha * z_shifted_eval + ( + (self.a_eval + beta * self.s1_eval + gamma) * + (self.b_eval + beta * self.s2_eval + gamma) * + (self.c_eval + beta * S3_big[i] + gamma) + ) * alpha * self.z_shifted_eval + ( (Z_big[i] - 1) * L1_ev ) * alpha**2 - ( T1_big[i] + @@ -265,37 +401,44 @@ def prove_from_witness(cls, setup: Setup, program: Program, witness: dict[Option zed ** (group_order * 2) * T3_big[i] ) * ZH_ev for i in range(4 * group_order)] - R_coeffs = expanded_evals_to_coeffs(R_big) + R_coeffs = self.expanded_evals_to_coeffs(R_big) assert R_coeffs[group_order:] == [0] * (group_order * 3) R = f_inner_fft(R_coeffs[:group_order]) - print('R_pt', setup.evaluations_to_point(R)) + print('R_pt', setup.commit(R)) assert barycentric_eval_at_point(R, zed) == 0 print("Generated linearization polynomial R") - buf3 = b''.join([ - serialize_int(x) for x in - (a_eval, b_eval, c_eval, s1_eval, s2_eval, z_shifted_eval) - ]) - v = binhash_to_f_inner(keccak256(buf3)) - # Generate proof that W(z) = 0 and that the provided evaluations of # A, B, C, S1, S2 are correct + A_big = self.fft_expand(self.A) + B_big = self.fft_expand(self.B) + C_big = self.fft_expand(self.C) + + QL_big, QR_big, QM_big, QO_big, QC_big, PI_big = \ + (self.fft_expand(x) for x in (self.QL, self.QR, self.QM, self.QO, self.QC, self.PI)) + S1_big = self.fft_expand(self.S1) + S2_big = self.fft_expand(self.S2) + S3_big = self.fft_expand(self.S3) + + roots_of_unity = get_roots_of_unity(group_order) + quarter_roots = get_roots_of_unity(group_order * 4) + W_z_big = [( R_big[i] + - v * (A_big[i] - a_eval) + - v**2 * (B_big[i] - b_eval) + - v**3 * (C_big[i] - c_eval) + - v**4 * (S1_big[i] - s1_eval) + - v**5 * (S2_big[i] - s2_eval) - ) / (fft_offset * quarter_roots[i] - zed) for i in range(group_order * 4)] - - W_z_coeffs = expanded_evals_to_coeffs(W_z_big) + v * (A_big[i] - self.a_eval) + + v**2 * (B_big[i] - self.b_eval) + + v**3 * (C_big[i] - self.c_eval) + + v**4 * (S1_big[i] - self.s1_eval) + + v**5 * (S2_big[i] - self.s2_eval) + ) / (transcript.fft_cofactor * quarter_roots[i] - zed) for i in range(group_order * 4)] + + W_z_coeffs = self.expanded_evals_to_coeffs(W_z_big) assert W_z_coeffs[group_order:] == [0] * (group_order * 3) W_z = f_inner_fft(W_z_coeffs[:group_order]) - W_z_1 = setup.evaluations_to_point(W_z) + W_z_1 = setup.commit(W_z) # Generate proof that the provided evaluation of Z(z*w) is correct. This # awkwardly different term is needed because the permutation accumulator @@ -303,18 +446,23 @@ def prove_from_witness(cls, setup: Setup, program: Program, witness: dict[Option # coordinates, and not just within one coordinate. W_zw_big = [ - (Z_big[i] - z_shifted_eval) / - (fft_offset * quarter_roots[i] - zed * roots_of_unity[1]) + (Z_big[i] - self.z_shifted_eval) / + (transcript.fft_cofactor * quarter_roots[i] - zed * roots_of_unity[1]) for i in range(group_order * 4)] - W_zw_coeffs = expanded_evals_to_coeffs(W_zw_big) + W_zw_coeffs = self.expanded_evals_to_coeffs(W_zw_big) assert W_zw_coeffs[group_order:] == [0] * (group_order * 3) W_zw = f_inner_fft(W_zw_coeffs[:group_order]) - W_zw_1 = setup.evaluations_to_point(W_zw) + W_zw_1 = setup.commit(W_zw) print("Generated final quotient witness polynomials") - return cls( - a_1, b_1, c_1, z_1, t_lo_1, t_mid_1, t_hi_1, - a_eval, b_eval, c_eval, s1_eval, s2_eval, z_shifted_eval, - W_z_1, W_zw_1, - ) + return W_z_1, W_zw_1 + + def fft_expand(self, x): + return fft_expand_with_offset(x, self.fft_cofactor) + + def expanded_evals_to_coeffs(self, x): + return offset_evals_to_coeffs(x, self.fft_cofactor) + + def rlc(self, term_1, term_2): + return term_1 + self.beta * term_2 + self.gamma diff --git a/setup.py b/setup.py index a037ca4..c24cfdd 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,16 @@ from utils import * import py_ecc.bn128 as b from typing import NewType +from curve import ec_lincomb, Scalar + +G1Point = NewType('G1Point', tuple[b.FQ, b.FQ]) +G2Point = NewType('G2Point', tuple[b.FQ2, b.FQ2]) # Recover the trusted setup from a file in the format used in # https://github.com/iden3/snarkjs#7-prepare-phase-2 SETUP_FILE_G1_STARTPOS = 80 SETUP_FILE_POWERS_POS = 60 - -G1Point = NewType('G1Point', tuple[b.FQ, b.FQ]) -G2Point = NewType('G2Point', tuple[b.FQ2, b.FQ2]) +Commitment = NewType('Commitment', G1Point) class Setup(object): # ([1]₁, [x]₁, ..., [x^{d-1}]₁) @@ -62,12 +64,10 @@ def from_file(cls, filename): # print("X^1 points checked consistent") return cls(G1_side, X2) - # Encodes the KZG commitment to the given polynomial coeffs - def coeffs_to_point(self, coeffs): + # Encodes the KZG commitment that evaluates to the given values in the group + def commit(self, values) -> Commitment: + # inverse FFT from Lagrange basis to monomial basis + coeffs = f_inner_fft(values, inv=True) if len(coeffs) > len(self.G1_side): - raise Exception("Not enough powers in setup") + raise Exception("Not enough powers in setup") return ec_lincomb([(s, x) for s, x in zip(self.G1_side, coeffs)]) - - # Encodes the KZG commitment that evaluates to the given values in the group - def evaluations_to_point(self, evals): - return self.coeffs_to_point(f_inner_fft(evals, inv=True)) diff --git a/test.py b/test.py index eca2e4d..a37f35a 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,6 @@ from compiler.program import Program from setup import Setup -from prover import Proof +from prover import Prover from verifier import VerificationKey import json from test.mini_poseidon import rc, mds, poseidon_hash @@ -58,7 +58,7 @@ def prover_test(setup): print("Beginning prover test") program = Program(['e public', 'c <== a * b', 'e <== c * d'], 8) assignments = {'a': 3, 'b': 4, 'c': 12, 'd': 5, 'e': 60} - return Proof.prove_from_witness(setup, program, assignments) + return Prover().prove(setup, program, assignments) print("Prover test success") def verifier_test(setup, proof): @@ -98,7 +98,7 @@ def factorization_test(setup): 'pb3': 1, 'pb2': 1, 'pb1': 0, 'pb0': 1, 'qb3': 0, 'qb2': 1, 'qb1': 1, 'qb0': 1, }) - proof = Proof.prove_from_witness(setup, program, assignments) + proof = Prover().prove(setup, program, assignments) print("Generated proof") assert vk.verify_proof(16, proof, public, optimized=True) print("Factorization test success!") @@ -139,7 +139,7 @@ def poseidon_test(setup): assignments = program.fill_variable_assignments({'L0': 1, 'M0': 2}) vk = VerificationKey.make_verification_key(program, setup) print("Generated verification key") - proof = Proof.prove_from_witness(setup, program, assignments) + proof = Prover().prove(setup, program, assignments) print("Generated proof") assert vk.verify_proof(1024, proof, [1, 2, expected_value]) print("Verified proof!") diff --git a/transcript.py b/transcript.py new file mode 100644 index 0000000..a7afa56 --- /dev/null +++ b/transcript.py @@ -0,0 +1,28 @@ +from Crypto.Hash import keccak +from typing import Optional, Union +from curve import Scalar +from setup import G1Point + +class Transcript: + beta: Optional[Scalar] = None + gamma: Optional[Scalar] = None + alpha: Optional[Scalar] = None + fft_cofactor: Optional[Scalar] = None + zed: Optional[Scalar] = None + v: Optional[Scalar] = None + + def __init__(self): + self.state = keccak.new(digest_bits=256) + + def hash_scalar(self, scalar: Scalar): + string = scalar.n.to_bytes(32, 'big') + self.state.update(string) + + def hash_point(self, point: G1Point): + string = point[0].n.to_bytes(32, 'big') + point[1].n.to_bytes(32, 'big') + self.state.update(string) + + def squeeze(self): + digest = self.state.digest() + self.state = keccak.new(digest_bits=256).update(digest) + return Scalar(int.from_bytes(digest, 'big')) diff --git a/verifier.py b/verifier.py index 15920c7..9c8da6f 100644 --- a/verifier.py +++ b/verifier.py @@ -2,10 +2,11 @@ from utils import * from dataclasses import dataclass from setup import G1Point, G2Point -from prover import Proof +from prover import Prover from compiler.program import Program from compiler.utils import Column from setup import Setup +from transcript import Transcript @dataclass class VerificationKey: @@ -37,14 +38,14 @@ def make_verification_key(cls, program: Program, setup: Setup): L, R, M, O, C = program.make_gate_polynomials() S = program.make_s_polynomials() return cls( - setup.evaluations_to_point(M), - setup.evaluations_to_point(L), - setup.evaluations_to_point(R), - setup.evaluations_to_point(O), - setup.evaluations_to_point(C), - setup.evaluations_to_point(S[Column.LEFT]), - setup.evaluations_to_point(S[Column.RIGHT]), - setup.evaluations_to_point(S[Column.OUTPUT]), + setup.commit(M), + setup.commit(L), + setup.commit(R), + setup.commit(O), + setup.commit(C), + setup.commit(S[Column.LEFT]), + setup.commit(S[Column.RIGHT]), + setup.commit(S[Column.OUTPUT]), setup.X2, get_root_of_unity(program.group_order) ) @@ -56,7 +57,7 @@ def make_verification_key(cls, program: Program, setup: Setup): def _verify_inner( self, group_order: int, - proof: Proof, + proof, PI_ev: f_inner, v: f_inner, zed: f_inner, @@ -69,35 +70,35 @@ def _verify_inner( L1_ev = ZH_ev / (group_order * (zed - 1)) R_pt = ec_lincomb([ - (self.Qm, proof.a_eval * proof.b_eval), - (self.Ql, proof.a_eval), - (self.Qr, proof.b_eval), - (self.Qo, proof.c_eval), + (self.Qm, proof['a_eval'] * proof['b_eval']), + (self.Ql, proof['a_eval']), + (self.Qr, proof['b_eval']), + (self.Qo, proof['c_eval']), (b.G1, PI_ev), (self.Qc, 1), - (proof.z_1, ( - (proof.a_eval + beta * zed + gamma) * - (proof.b_eval + beta * 2 * zed + gamma) * - (proof.c_eval + beta * 3 * zed + gamma) * + (proof['z_1'], ( + (proof['a_eval'] + beta * zed + gamma) * + (proof['b_eval'] + beta * 2 * zed + gamma) * + (proof['c_eval'] + beta * 3 * zed + gamma) * alpha )), (self.S3, ( - -(proof.a_eval + beta * proof.s1_eval + gamma) * - (proof.b_eval + beta * proof.s2_eval + gamma) * + -(proof['a_eval'] + beta * proof['s1_eval'] + gamma) * + (proof['b_eval'] + beta * proof['s2_eval'] + gamma) * beta * - alpha * proof.z_shifted_eval + alpha * proof['z_shifted_eval'] )), (b.G1, ( - -(proof.a_eval + beta * proof.s1_eval + gamma) * - (proof.b_eval + beta * proof.s2_eval + gamma) * - (proof.c_eval + gamma) * - alpha * proof.z_shifted_eval + -(proof['a_eval'] + beta * proof['s1_eval'] + gamma) * + (proof['b_eval'] + beta * proof['s2_eval'] + gamma) * + (proof['c_eval'] + gamma) * + alpha * proof['z_shifted_eval'] )), - (proof.z_1, L1_ev * alpha ** 2), + (proof['z_1'], L1_ev * alpha ** 2), (b.G1, -L1_ev * alpha ** 2), - (proof.t_lo_1, -ZH_ev), - (proof.t_mid_1, -ZH_ev * zed**group_order), - (proof.t_hi_1, -ZH_ev * zed**(group_order*2)), + (proof['t_lo_1'], -ZH_ev), + (proof['t_mid_1'], -ZH_ev * zed**group_order), + (proof['t_hi_1'], -ZH_ev * zed**(group_order*2)), ]) print('verifier R_pt', R_pt) @@ -108,20 +109,20 @@ def _verify_inner( b.G2, ec_lincomb([ (R_pt, 1), - (proof.a_1, v), - (b.G1, -v * proof.a_eval), - (proof.b_1, v**2), - (b.G1, -v**2 * proof.b_eval), - (proof.c_1, v**3), - (b.G1, -v**3 * proof.c_eval), + (proof['a_1'], v), + (b.G1, -v * proof['a_eval']), + (proof['b_1'], v**2), + (b.G1, -v**2 * proof['b_eval']), + (proof['c_1'], v**3), + (b.G1, -v**3 * proof['c_eval']), (self.S1, v**4), - (b.G1, -v**4 * proof.s1_eval), + (b.G1, -v**4 * proof['s1_eval']), (self.S2, v**5), - (b.G1, -v**5 * proof.s2_eval), + (b.G1, -v**5 * proof['s2_eval']), ]) ) == b.pairing( b.add(self.X_2, ec_mul(b.G2, -zed)), - proof.W_z_1 + proof['W_z_1'] ) print("done check 1") @@ -129,12 +130,12 @@ def _verify_inner( assert b.pairing( b.G2, ec_lincomb([ - (proof.z_1, 1), - (b.G1, -proof.z_shifted_eval) + (proof['z_1'], 1), + (b.G1, -proof['z_shifted_eval']) ]) ) == b.pairing( b.add(self.X_2, ec_mul(b.G2, -zed * root_of_unity)), - proof.W_zw_1 + proof['W_zw_1'] ) print("done check 2") return True @@ -146,7 +147,7 @@ def _verify_inner( def _optimized_verify_inner( self, group_order: int, - proof: Proof, + proof, PI_ev: f_inner, v: f_inner, u: f_inner, @@ -165,49 +166,49 @@ def _optimized_verify_inner( r0 = ( PI_ev - L1_ev * alpha ** 2 - ( alpha * - (proof.a_eval + beta * proof.s1_eval + gamma) * - (proof.b_eval + beta * proof.s2_eval + gamma) * - (proof.c_eval + gamma) * - proof.z_shifted_eval + (proof['a_eval'] + beta * proof['s1_eval'] + gamma) * + (proof['b_eval'] + beta * proof['s2_eval'] + gamma) * + (proof['c_eval'] + gamma) * + proof['z_shifted_eval'] ) ) # D = (R - r0) + u * Z D_pt = ec_lincomb([ - (self.Qm, proof.a_eval * proof.b_eval), - (self.Ql, proof.a_eval), - (self.Qr, proof.b_eval), - (self.Qo, proof.c_eval), + (self.Qm, proof['a_eval'] * proof['b_eval']), + (self.Ql, proof['a_eval']), + (self.Qr, proof['b_eval']), + (self.Qo, proof['c_eval']), (self.Qc, 1), - (proof.z_1, ( - (proof.a_eval + beta * zed + gamma) * - (proof.b_eval + beta * 2 * zed + gamma) * - (proof.c_eval + beta * 3 * zed + gamma) * alpha + + (proof['z_1'], ( + (proof['a_eval'] + beta * zed + gamma) * + (proof['b_eval'] + beta * 2 * zed + gamma) * + (proof['c_eval'] + beta * 3 * zed + gamma) * alpha + L1_ev * alpha ** 2 + u )), (self.S3, ( - -(proof.a_eval + beta * proof.s1_eval + gamma) * - (proof.b_eval + beta * proof.s2_eval + gamma) * - alpha * beta * proof.z_shifted_eval + -(proof['a_eval'] + beta * proof['s1_eval'] + gamma) * + (proof['b_eval'] + beta * proof['s2_eval'] + gamma) * + alpha * beta * proof['z_shifted_eval'] )), - (proof.t_lo_1, -ZH_ev), - (proof.t_mid_1, -ZH_ev * zed**group_order), - (proof.t_hi_1, -ZH_ev * zed**(group_order*2)), + (proof['t_lo_1'], -ZH_ev), + (proof['t_mid_1'], -ZH_ev * zed**group_order), + (proof['t_hi_1'], -ZH_ev * zed**(group_order*2)), ]) F_pt = ec_lincomb([ (D_pt, 1), - (proof.a_1, v), - (proof.b_1, v**2), - (proof.c_1, v**3), + (proof['a_1'], v), + (proof['b_1'], v**2), + (proof['c_1'], v**3), (self.S1, v**4), (self.S2, v**5), ]) E_pt = ec_mul(b.G1, ( - -r0 + v * proof.a_eval + v**2 * proof.b_eval + v**3 * proof.c_eval + - v**4 * proof.s1_eval + v**5 * proof.s2_eval + u * proof.z_shifted_eval + -r0 + v * proof['a_eval'] + v**2 * proof['b_eval'] + v**3 * proof['c_eval'] + + v**4 * proof['s1_eval'] + v**5 * proof['s2_eval'] + u * proof['z_shifted_eval'] )) # What's going on here is a clever re-arrangement of terms to check @@ -226,11 +227,11 @@ def _optimized_verify_inner( # so at this point we can take a random linear combination of the two # checks, and verify it with only one pairing. assert b.pairing(self.X_2, ec_lincomb([ - (proof.W_z_1, 1), - (proof.W_zw_1, u) + (proof['W_z_1'], 1), + (proof['W_zw_1'], u) ])) == b.pairing(b.G2, ec_lincomb([ - (proof.W_z_1, zed), - (proof.W_zw_1, u * zed * root_of_unity), + (proof['W_z_1'], zed), + (proof['W_zw_1'], u * zed * root_of_unity), (F_pt, 1), (E_pt, -1) ])) @@ -238,27 +239,42 @@ def _optimized_verify_inner( print("done combined check") return True - def verify_proof(self, group_order: int, proof: Proof, public=[], optimized=True) -> bool: + def verify_proof(self, group_order: int, proof, public=[], optimized=True) -> bool: # Compute challenges (should be same as those computed by prover) - buf = serialize_point(proof.a_1) + serialize_point(proof.b_1) + serialize_point(proof.c_1) + transcript = Transcript() + transcript.hash_point(proof['a_1']) + transcript.hash_point(proof['b_1']) + transcript.hash_point(proof['c_1']) - beta = binhash_to_f_inner(keccak256(buf)) - gamma = binhash_to_f_inner(keccak256(keccak256(buf))) + beta = transcript.squeeze() + gamma = transcript.squeeze() - alpha = binhash_to_f_inner(keccak256(serialize_point(proof.z_1))) + transcript.hash_point(proof['z_1']) + alpha = transcript.squeeze() - buf2 = serialize_point(proof.t_lo_1)+serialize_point(proof.t_mid_1)+serialize_point(proof.t_hi_1) - zed = binhash_to_f_inner(keccak256(buf2)) + fft_cofactor = transcript.squeeze() - buf3 = b''.join([ - serialize_int(x) for x in - (proof.a_eval, proof.b_eval, proof.c_eval, proof.s1_eval, proof.s2_eval, proof.z_shifted_eval) - ]) - v = binhash_to_f_inner(keccak256(buf3)) + transcript.hash_point(proof['t_lo_1']) + transcript.hash_point(proof['t_mid_1']) + transcript.hash_point(proof['t_hi_1']) + + zed = transcript.squeeze() + + transcript.hash_scalar(proof['a_eval']) + transcript.hash_scalar(proof['b_eval']) + transcript.hash_scalar(proof['c_eval']) + transcript.hash_scalar(proof['s1_eval']) + transcript.hash_scalar(proof['s2_eval']) + transcript.hash_scalar(proof['z_shifted_eval']) + + v = transcript.squeeze() + + transcript.hash_point(proof['W_z_1']) + transcript.hash_point(proof['W_zw_1']) # Does not need to be standardized, only needs to be unpredictable - u = binhash_to_f_inner(keccak256(buf + buf2 + buf3)) + u = transcript.squeeze() PI_ev = barycentric_eval_at_point( [f_inner(-x) for x in public] +