diff --git a/zk-sdk/src/lib.rs b/zk-sdk/src/lib.rs index 8119c35acfc30c..2c7e591e02a2a3 100644 --- a/zk-sdk/src/lib.rs +++ b/zk-sdk/src/lib.rs @@ -22,6 +22,7 @@ #[cfg(not(target_os = "solana"))] pub mod encryption; pub mod errors; +mod range_proof; #[cfg(not(target_os = "solana"))] mod sigma_proofs; #[cfg(not(target_os = "solana"))] diff --git a/zk-sdk/src/range_proof/errors.rs b/zk-sdk/src/range_proof/errors.rs new file mode 100644 index 00000000000000..f3c304fab2e6d7 --- /dev/null +++ b/zk-sdk/src/range_proof/errors.rs @@ -0,0 +1,46 @@ +//! Errors related to proving and verifying range proofs. +use {crate::errors::TranscriptError, thiserror::Error}; + +#[cfg(not(target_os = "solana"))] +#[derive(Error, Clone, Debug, Eq, PartialEq)] +pub enum RangeProofGenerationError { + #[error("maximum generator length exceeded")] + MaximumGeneratorLengthExceeded, + #[error("amounts, commitments, openings, or bit lengths vectors have different lengths")] + VectorLengthMismatch, + #[error("invalid bit size")] + InvalidBitSize, + #[error("insufficient generators for the proof")] + GeneratorLengthMismatch, + #[error("inner product length mismatch")] + InnerProductLengthMismatch, +} + +#[derive(Error, Clone, Debug, Eq, PartialEq)] +pub enum RangeProofVerificationError { + #[error("required algebraic relation does not hold")] + AlgebraicRelation, + #[error("malformed proof")] + Deserialization, + #[error("multiscalar multiplication failed")] + MultiscalarMul, + #[error("transcript failed to produce a challenge")] + Transcript(#[from] TranscriptError), + #[error( + "attempted to verify range proof with a non-power-of-two bit size or bit size is too big" + )] + InvalidBitSize, + #[error("insufficient generators for the proof")] + InvalidGeneratorsLength, + #[error("maximum generator length exceeded")] + MaximumGeneratorLengthExceeded, + #[error("commitments and bit lengths vectors have different lengths")] + VectorLengthMismatch, +} + +#[cfg(not(target_os = "solana"))] +#[derive(Error, Clone, Debug, Eq, PartialEq)] +pub enum RangeProofGeneratorError { + #[error("maximum generator length exceeded")] + MaximumGeneratorLengthExceeded, +} diff --git a/zk-sdk/src/range_proof/generators.rs b/zk-sdk/src/range_proof/generators.rs new file mode 100644 index 00000000000000..8a3ccf5a29e842 --- /dev/null +++ b/zk-sdk/src/range_proof/generators.rs @@ -0,0 +1,155 @@ +use { + crate::range_proof::errors::RangeProofGeneratorError, + curve25519_dalek::{ + digest::{ExtendableOutput, Update, XofReader}, + ristretto::RistrettoPoint, + }, + sha3::{Sha3XofReader, Shake256}, +}; + +const MAX_GENERATOR_LENGTH: usize = u32::MAX as usize; + +/// Generators for Pedersen vector commitments that are used for inner-product proofs. +struct GeneratorsChain { + reader: Sha3XofReader, +} + +impl GeneratorsChain { + /// Creates a chain of generators, determined by the hash of `label`. + fn new(label: &[u8]) -> Self { + let mut shake = Shake256::default(); + shake.update(b"GeneratorsChain"); + shake.update(label); + + GeneratorsChain { + reader: shake.finalize_xof(), + } + } + + /// Advances the reader n times, squeezing and discarding + /// the result. + fn fast_forward(mut self, n: usize) -> Self { + for _ in 0..n { + let mut buf = [0u8; 64]; + self.reader.read(&mut buf); + } + self + } +} + +impl Default for GeneratorsChain { + fn default() -> Self { + Self::new(&[]) + } +} + +impl Iterator for GeneratorsChain { + type Item = RistrettoPoint; + + fn next(&mut self) -> Option { + let mut uniform_bytes = [0u8; 64]; + self.reader.read(&mut uniform_bytes); + + Some(RistrettoPoint::from_uniform_bytes(&uniform_bytes)) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), None) + } +} + +#[allow(non_snake_case)] +#[derive(Clone)] +pub struct RangeProofGens { + /// The maximum number of usable generators. + pub gens_capacity: usize, + /// Precomputed \\(\mathbf G\\) generators. + G_vec: Vec, + /// Precomputed \\(\mathbf H\\) generators. + H_vec: Vec, +} + +impl RangeProofGens { + pub fn new(gens_capacity: usize) -> Result { + let mut gens = RangeProofGens { + gens_capacity: 0, + G_vec: Vec::new(), + H_vec: Vec::new(), + }; + gens.increase_capacity(gens_capacity)?; + Ok(gens) + } + + /// Increases the generators' capacity to the amount specified. + /// If less than or equal to the current capacity, does nothing. + pub fn increase_capacity( + &mut self, + new_capacity: usize, + ) -> Result<(), RangeProofGeneratorError> { + if self.gens_capacity >= new_capacity { + return Ok(()); + } + + if new_capacity > MAX_GENERATOR_LENGTH { + return Err(RangeProofGeneratorError::MaximumGeneratorLengthExceeded); + } + + self.G_vec.extend( + &mut GeneratorsChain::new(&[b'G']) + .fast_forward(self.gens_capacity) + .take(new_capacity - self.gens_capacity), + ); + + self.H_vec.extend( + &mut GeneratorsChain::new(&[b'H']) + .fast_forward(self.gens_capacity) + .take(new_capacity - self.gens_capacity), + ); + + self.gens_capacity = new_capacity; + Ok(()) + } + + #[allow(non_snake_case)] + pub(crate) fn G(&self, n: usize) -> impl Iterator { + GensIter { + array: &self.G_vec, + n, + gen_idx: 0, + } + } + + #[allow(non_snake_case)] + pub(crate) fn H(&self, n: usize) -> impl Iterator { + GensIter { + array: &self.H_vec, + n, + gen_idx: 0, + } + } +} + +struct GensIter<'a> { + array: &'a Vec, + n: usize, + gen_idx: usize, +} + +impl<'a> Iterator for GensIter<'a> { + type Item = &'a RistrettoPoint; + + fn next(&mut self) -> Option { + if self.gen_idx >= self.n { + None + } else { + let cur_gen = self.gen_idx; + self.gen_idx += 1; + Some(&self.array[cur_gen]) + } + } + + fn size_hint(&self) -> (usize, Option) { + let size = self.n - self.gen_idx; + (size, Some(size)) + } +} diff --git a/zk-sdk/src/range_proof/inner_product.rs b/zk-sdk/src/range_proof/inner_product.rs new file mode 100644 index 00000000000000..4dcef2eafed772 --- /dev/null +++ b/zk-sdk/src/range_proof/inner_product.rs @@ -0,0 +1,506 @@ +use { + crate::{ + range_proof::{ + errors::{RangeProofGenerationError, RangeProofVerificationError}, + util, + }, + transcript::TranscriptProtocol, + }, + core::iter, + curve25519_dalek::{ + ristretto::{CompressedRistretto, RistrettoPoint}, + scalar::Scalar, + traits::{MultiscalarMul, VartimeMultiscalarMul}, + }, + merlin::Transcript, + std::borrow::Borrow, +}; + +#[allow(non_snake_case)] +#[derive(Clone)] +pub struct InnerProductProof { + pub L_vec: Vec, // 32 * log(bit_length) + pub R_vec: Vec, // 32 * log(bit_length) + pub a: Scalar, // 32 bytes + pub b: Scalar, // 32 bytes +} + +#[allow(non_snake_case)] +impl InnerProductProof { + /// Create an inner-product proof. + /// + /// The proof is created with respect to the bases \\(G\\), \\(H'\\), + /// where \\(H'\_i = H\_i \cdot \texttt{Hprime\\_factors}\_i\\). + /// + /// The `verifier` is passed in as a parameter so that the + /// challenges depend on the *entire* transcript (including parent + /// protocols). + /// + /// The lengths of the vectors must all be the same, and must all be + /// a power of 2. + #[allow(clippy::too_many_arguments)] + pub fn new( + Q: &RistrettoPoint, + G_factors: &[Scalar], + H_factors: &[Scalar], + mut G_vec: Vec, + mut H_vec: Vec, + mut a_vec: Vec, + mut b_vec: Vec, + transcript: &mut Transcript, + ) -> Result { + // Create slices G, H, a, b backed by their respective + // vectors. This lets us reslice as we compress the lengths + // of the vectors in the main loop below. + let mut G = &mut G_vec[..]; + let mut H = &mut H_vec[..]; + let mut a = &mut a_vec[..]; + let mut b = &mut b_vec[..]; + + let mut n = G.len(); + + // All of the input vectors must have the same length. + if G.len() != n + || H.len() != n + || a.len() != n + || b.len() != n + || G_factors.len() != n + || H_factors.len() != n + { + return Err(RangeProofGenerationError::GeneratorLengthMismatch); + } + + // All of the input vectors must have a length that is a power of two. + if !n.is_power_of_two() { + return Err(RangeProofGenerationError::InvalidBitSize); + } + + transcript.innerproduct_domain_separator(n as u64); + + let lg_n = n.next_power_of_two().trailing_zeros() as usize; + let mut L_vec = Vec::with_capacity(lg_n); + let mut R_vec = Vec::with_capacity(lg_n); + + // If it's the first iteration, unroll the Hprime = H*y_inv scalar mults + // into multiscalar muls, for performance. + if n != 1 { + n = n.checked_div(2).unwrap(); + let (a_L, a_R) = a.split_at_mut(n); + let (b_L, b_R) = b.split_at_mut(n); + let (G_L, G_R) = G.split_at_mut(n); + let (H_L, H_R) = H.split_at_mut(n); + + let c_L = util::inner_product(a_L, b_R) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; + let c_R = util::inner_product(a_R, b_L) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; + + let L = RistrettoPoint::multiscalar_mul( + a_L.iter() + // `n` was previously divided in half and therefore, it cannot overflow. + .zip(G_factors[n..n.checked_mul(2).unwrap()].iter()) + .map(|(a_L_i, g)| a_L_i * g) + .chain( + b_R.iter() + .zip(H_factors[0..n].iter()) + .map(|(b_R_i, h)| b_R_i * h), + ) + .chain(iter::once(c_L)), + G_R.iter().chain(H_L.iter()).chain(iter::once(Q)), + ) + .compress(); + + let R = RistrettoPoint::multiscalar_mul( + a_R.iter() + .zip(G_factors[0..n].iter()) + .map(|(a_R_i, g)| a_R_i * g) + .chain( + b_L.iter() + .zip(H_factors[n..n.checked_mul(2).unwrap()].iter()) + .map(|(b_L_i, h)| b_L_i * h), + ) + .chain(iter::once(c_R)), + G_L.iter().chain(H_R.iter()).chain(iter::once(Q)), + ) + .compress(); + + L_vec.push(L); + R_vec.push(R); + + transcript.append_point(b"L", &L); + transcript.append_point(b"R", &R); + + let u = transcript.challenge_scalar(b"u"); + let u_inv = u.invert(); + + for i in 0..n { + a_L[i] = a_L[i] * u + u_inv * a_R[i]; + b_L[i] = b_L[i] * u_inv + u * b_R[i]; + G_L[i] = RistrettoPoint::multiscalar_mul( + &[ + u_inv * G_factors[i], + u * G_factors[n.checked_add(i).unwrap()], + ], + &[G_L[i], G_R[i]], + ); + H_L[i] = RistrettoPoint::multiscalar_mul( + &[ + u * H_factors[i], + u_inv * H_factors[n.checked_add(i).unwrap()], + ], + &[H_L[i], H_R[i]], + ) + } + + a = a_L; + b = b_L; + G = G_L; + H = H_L; + } + + while n != 1 { + n = n.checked_div(2).unwrap(); + let (a_L, a_R) = a.split_at_mut(n); + let (b_L, b_R) = b.split_at_mut(n); + let (G_L, G_R) = G.split_at_mut(n); + let (H_L, H_R) = H.split_at_mut(n); + + let c_L = util::inner_product(a_L, b_R) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; + let c_R = util::inner_product(a_R, b_L) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; + + let L = RistrettoPoint::multiscalar_mul( + a_L.iter().chain(b_R.iter()).chain(iter::once(&c_L)), + G_R.iter().chain(H_L.iter()).chain(iter::once(Q)), + ) + .compress(); + + let R = RistrettoPoint::multiscalar_mul( + a_R.iter().chain(b_L.iter()).chain(iter::once(&c_R)), + G_L.iter().chain(H_R.iter()).chain(iter::once(Q)), + ) + .compress(); + + L_vec.push(L); + R_vec.push(R); + + transcript.append_point(b"L", &L); + transcript.append_point(b"R", &R); + + let u = transcript.challenge_scalar(b"u"); + let u_inv = u.invert(); + + for i in 0..n { + a_L[i] = a_L[i] * u + u_inv * a_R[i]; + b_L[i] = b_L[i] * u_inv + u * b_R[i]; + G_L[i] = RistrettoPoint::multiscalar_mul(&[u_inv, u], &[G_L[i], G_R[i]]); + H_L[i] = RistrettoPoint::multiscalar_mul(&[u, u_inv], &[H_L[i], H_R[i]]); + } + + a = a_L; + b = b_L; + G = G_L; + H = H_L; + } + + Ok(InnerProductProof { + L_vec, + R_vec, + a: a[0], + b: b[0], + }) + } + + /// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and + /// \\([s\_{i}]\\) for combined multiscalar multiplication in a parent protocol. See [inner + /// product protocol notes](index.html#verification-equation) for details. The verifier must + /// provide the input length \\(n\\) explicitly to avoid unbounded allocation within the inner + /// product proof. + #[allow(clippy::type_complexity)] + pub(crate) fn verification_scalars( + &self, + n: usize, + transcript: &mut Transcript, + ) -> Result<(Vec, Vec, Vec), RangeProofVerificationError> { + let lg_n = self.L_vec.len(); + if lg_n == 0 || lg_n >= 32 { + // 4 billion multiplications should be enough for anyone + // and this check prevents overflow in 1<( + &self, + n: usize, + G_factors: IG, + H_factors: IH, + P: &RistrettoPoint, + Q: &RistrettoPoint, + G: &[RistrettoPoint], + H: &[RistrettoPoint], + transcript: &mut Transcript, + ) -> Result<(), RangeProofVerificationError> + where + IG: IntoIterator, + IG::Item: Borrow, + IH: IntoIterator, + IH::Item: Borrow, + { + let (u_sq, u_inv_sq, s) = self.verification_scalars(n, transcript)?; + + let g_times_a_times_s = G_factors + .into_iter() + .zip(s.iter()) + .map(|(g_i, s_i)| (self.a * s_i) * g_i.borrow()) + .take(G.len()); + + // 1/s[i] is s[!i], and !i runs from n-1 to 0 as i runs from 0 to n-1 + let inv_s = s.iter().rev(); + + let h_times_b_div_s = H_factors + .into_iter() + .zip(inv_s) + .map(|(h_i, s_i_inv)| (self.b * s_i_inv) * h_i.borrow()); + + let neg_u_sq = u_sq.iter().map(|ui| -ui); + let neg_u_inv_sq = u_inv_sq.iter().map(|ui| -ui); + + let Ls = self + .L_vec + .iter() + .map(|p| { + p.decompress() + .ok_or(RangeProofVerificationError::Deserialization) + }) + .collect::, _>>()?; + + let Rs = self + .R_vec + .iter() + .map(|p| { + p.decompress() + .ok_or(RangeProofVerificationError::Deserialization) + }) + .collect::, _>>()?; + + let expect_P = RistrettoPoint::vartime_multiscalar_mul( + iter::once(self.a * self.b) + .chain(g_times_a_times_s) + .chain(h_times_b_div_s) + .chain(neg_u_sq) + .chain(neg_u_inv_sq), + iter::once(Q) + .chain(G.iter()) + .chain(H.iter()) + .chain(Ls.iter()) + .chain(Rs.iter()), + ); + + if expect_P == *P { + Ok(()) + } else { + Err(RangeProofVerificationError::AlgebraicRelation) + } + } + + /// Returns the size in bytes required to serialize the inner + /// product proof. + /// + /// For vectors of length `n` the proof size is + /// \\(32 \cdot (2\lg n+2)\\) bytes. + pub fn serialized_size(&self) -> usize { + (self.L_vec.len() * 2 + 2) * 32 + } + + /// Serializes the proof into a byte array of \\(2n+2\\) 32-byte elements. + /// The layout of the inner product proof is: + /// * \\(n\\) pairs of compressed Ristretto points \\(L_0, R_0 \dots, L_{n-1}, R_{n-1}\\), + /// * two scalars \\(a, b\\). + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::with_capacity(self.serialized_size()); + for (l, r) in self.L_vec.iter().zip(self.R_vec.iter()) { + buf.extend_from_slice(l.as_bytes()); + buf.extend_from_slice(r.as_bytes()); + } + buf.extend_from_slice(self.a.as_bytes()); + buf.extend_from_slice(self.b.as_bytes()); + buf + } + + /// Deserializes the proof from a byte slice. + /// Returns an error in the following cases: + /// * the slice does not have \\(2n+2\\) 32-byte elements, + /// * \\(n\\) is larger or equal to 32 (proof is too big), + /// * any of \\(2n\\) points are not valid compressed Ristretto points, + /// * any of 2 scalars are not canonical scalars modulo Ristretto group order. + pub fn from_bytes(slice: &[u8]) -> Result { + let b = slice.len(); + if b % 32 != 0 { + return Err(RangeProofVerificationError::Deserialization); + } + let num_elements = b / 32; + if num_elements < 2 { + return Err(RangeProofVerificationError::Deserialization); + } + if (num_elements - 2) % 2 != 0 { + return Err(RangeProofVerificationError::Deserialization); + } + let lg_n = (num_elements - 2) / 2; + if lg_n >= 32 { + return Err(RangeProofVerificationError::Deserialization); + } + + let mut L_vec: Vec = Vec::with_capacity(lg_n); + let mut R_vec: Vec = Vec::with_capacity(lg_n); + for i in 0..lg_n { + let pos = 2 * i * 32; + L_vec.push(CompressedRistretto(util::read32(&slice[pos..]))); + R_vec.push(CompressedRistretto(util::read32(&slice[pos + 32..]))); + } + + let pos = 2 * lg_n * 32; + let a = Scalar::from_canonical_bytes(util::read32(&slice[pos..])) + .ok_or(RangeProofVerificationError::Deserialization)?; + let b = Scalar::from_canonical_bytes(util::read32(&slice[pos + 32..])) + .ok_or(RangeProofVerificationError::Deserialization)?; + + Ok(InnerProductProof { L_vec, R_vec, a, b }) + } +} + +#[cfg(test)] +mod tests { + use { + super::*, crate::range_proof::generators::RangeProofGens, rand::rngs::OsRng, sha3::Sha3_512, + }; + + #[test] + #[allow(non_snake_case)] + fn test_basic_correctness() { + let n = 32; + + let bp_gens = RangeProofGens::new(n).unwrap(); + let G: Vec = bp_gens.G(n).cloned().collect(); + let H: Vec = bp_gens.H(n).cloned().collect(); + + let Q = RistrettoPoint::hash_from_bytes::(b"test point"); + + let a: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect(); + let b: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect(); + let c = util::inner_product(&a, &b).unwrap(); + + let G_factors: Vec = iter::repeat(Scalar::one()).take(n).collect(); + + let y_inv = Scalar::random(&mut OsRng); + let H_factors: Vec = util::exp_iter(y_inv).take(n).collect(); + + // P would be determined upstream, but we need a correct P to check the proof. + // + // To generate P = + + Q, compute + // P = + + Q, + // where b' = b \circ y^(-n) + let b_prime = b.iter().zip(util::exp_iter(y_inv)).map(|(bi, yi)| bi * yi); + // a.iter() has Item=&Scalar, need Item=Scalar to chain with b_prime + let a_prime = a.iter().cloned(); + + let P = RistrettoPoint::vartime_multiscalar_mul( + a_prime.chain(b_prime).chain(iter::once(c)), + G.iter().chain(H.iter()).chain(iter::once(&Q)), + ); + + let mut prover_transcript = Transcript::new(b"innerproducttest"); + let mut verifier_transcript = Transcript::new(b"innerproducttest"); + + let proof = InnerProductProof::new( + &Q, + &G_factors, + &H_factors, + G.clone(), + H.clone(), + a.clone(), + b.clone(), + &mut prover_transcript, + ) + .unwrap(); + + assert!(proof + .verify( + n, + iter::repeat(Scalar::one()).take(n), + util::exp_iter(y_inv).take(n), + &P, + &Q, + &G, + &H, + &mut verifier_transcript, + ) + .is_ok()); + + let proof = InnerProductProof::from_bytes(proof.to_bytes().as_slice()).unwrap(); + let mut verifier_transcript = Transcript::new(b"innerproducttest"); + assert!(proof + .verify( + n, + iter::repeat(Scalar::one()).take(n), + util::exp_iter(y_inv).take(n), + &P, + &Q, + &G, + &H, + &mut verifier_transcript, + ) + .is_ok()); + } +} diff --git a/zk-sdk/src/range_proof/mod.rs b/zk-sdk/src/range_proof/mod.rs new file mode 100644 index 00000000000000..1808096cadd870 --- /dev/null +++ b/zk-sdk/src/range_proof/mod.rs @@ -0,0 +1,479 @@ +//! The Bulletproofs range-proof implementation over Curve25519 Ristretto points. +//! +//! The implementation is based on the dalek-cryptography bulletproofs +//! [implementation](https://github.com/dalek-cryptography/bulletproofs). Compared to the original +//! implementation by dalek-cryptography: +//! - This implementation focuses on the range proof implementation, while the dalek-cryptography +//! crate additionally implements the general bulletproofs implementation for languages that can be +//! represented by arithmetic circuits as well as MPC. +//! - This implementation implements a non-interactive range proof aggregation that is specified in +//! the original Bulletproofs [paper](https://eprint.iacr.org/2017/1066) (Section 4.3). +//! + +#![allow(dead_code)] + +#[cfg(not(target_os = "solana"))] +use { + crate::encryption::pedersen::{Pedersen, PedersenCommitment, PedersenOpening}, + crate::{ + encryption::pedersen::{G, H}, + range_proof::{ + errors::{RangeProofGenerationError, RangeProofVerificationError}, + generators::RangeProofGens, + inner_product::InnerProductProof, + }, + transcript::TranscriptProtocol, + }, + core::iter, + curve25519_dalek::traits::MultiscalarMul, + curve25519_dalek::{ + ristretto::{CompressedRistretto, RistrettoPoint}, + scalar::Scalar, + traits::{IsIdentity, VartimeMultiscalarMul}, + }, + merlin::Transcript, + rand::rngs::OsRng, + subtle::{Choice, ConditionallySelectable}, +}; + +pub mod errors; +#[cfg(not(target_os = "solana"))] +pub mod generators; +#[cfg(not(target_os = "solana"))] +pub mod inner_product; +#[cfg(not(target_os = "solana"))] +pub mod util; + +#[allow(non_snake_case)] +#[cfg(not(target_os = "solana"))] +#[derive(Clone)] +pub struct RangeProof { + pub A: CompressedRistretto, // 32 bytes + pub S: CompressedRistretto, // 32 bytes + pub T_1: CompressedRistretto, // 32 bytes + pub T_2: CompressedRistretto, // 32 bytes + pub t_x: Scalar, // 32 bytes + pub t_x_blinding: Scalar, // 32 bytes + pub e_blinding: Scalar, // 32 bytes + pub ipp_proof: InnerProductProof, // 448 bytes for withdraw; 512 for transfer +} + +#[allow(non_snake_case)] +#[cfg(not(target_os = "solana"))] +impl RangeProof { + /// Create an aggregated range proof. + /// + /// The proof is created with respect to a vector of Pedersen commitments C_1, ..., C_m. The + /// method itself does not take in these commitments, but the values associated with the commitments: + /// - vector of committed amounts v_1, ..., v_m represented as u64 + /// - bit-lengths of the committed amounts + /// - Pedersen openings for each commitments + /// + /// The sum of the bit-lengths of the commitments amounts must be a power-of-two + #[allow(clippy::many_single_char_names)] + #[cfg(not(target_os = "solana"))] + pub fn new( + amounts: Vec, + bit_lengths: Vec, + openings: Vec<&PedersenOpening>, + transcript: &mut Transcript, + ) -> Result { + // amounts, bit-lengths, openings must be same length vectors + let m = amounts.len(); + if bit_lengths.len() != m || openings.len() != m { + return Err(RangeProofGenerationError::VectorLengthMismatch); + } + + // each bit length must be greater than 0 for the proof to make sense + if bit_lengths + .iter() + .any(|bit_length| *bit_length == 0 || *bit_length > u64::BITS as usize) + { + return Err(RangeProofGenerationError::InvalidBitSize); + } + + // total vector dimension to compute the ultimate inner product proof for + let nm: usize = bit_lengths.iter().sum(); + if !nm.is_power_of_two() { + return Err(RangeProofGenerationError::VectorLengthMismatch); + } + + let bp_gens = RangeProofGens::new(nm) + .map_err(|_| RangeProofGenerationError::MaximumGeneratorLengthExceeded)?; + + // bit-decompose values and generate their Pedersen vector commitment + let a_blinding = Scalar::random(&mut OsRng); + let mut A = a_blinding * &(*H); + + let mut gens_iter = bp_gens.G(nm).zip(bp_gens.H(nm)); + for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) { + for j in 0..(*n_i) { + let (G_ij, H_ij) = gens_iter.next().unwrap(); + + // `j` is guaranteed to be at most `u64::BITS` (a 6-bit number) and therefore, + // casting is lossless and right shift can be safely unwrapped + let v_ij = Choice::from((amount_i.checked_shr(j as u32).unwrap() & 1) as u8); + let mut point = -H_ij; + point.conditional_assign(G_ij, v_ij); + A += point; + } + } + let A = A.compress(); + + // generate blinding factors and generate their Pedersen vector commitment + let s_L: Vec = (0..nm).map(|_| Scalar::random(&mut OsRng)).collect(); + let s_R: Vec = (0..nm).map(|_| Scalar::random(&mut OsRng)).collect(); + + // generate blinding factor for Pedersen commitment; `s_blinding` should not to be confused + // with blinding factors for the actual inner product vector + let s_blinding = Scalar::random(&mut OsRng); + + let S = RistrettoPoint::multiscalar_mul( + iter::once(&s_blinding).chain(s_L.iter()).chain(s_R.iter()), + iter::once(&(*H)).chain(bp_gens.G(nm)).chain(bp_gens.H(nm)), + ) + .compress(); + + // add the Pedersen vector commitments to the transcript (send the commitments to the verifier) + transcript.append_point(b"A", &A); + transcript.append_point(b"S", &S); + + // derive challenge scalars from the transcript (receive challenge from the verifier): `y` + // and `z` used for merge multiple inner product relations into one single inner product + let y = transcript.challenge_scalar(b"y"); + let z = transcript.challenge_scalar(b"z"); + + // define blinded vectors: + // - l(x) = (a_L - z*1) + s_L*x + // - r(x) = (y^n * (a_R + z*1) + [z^2*2^n | z^3*2^n | ... | z^m*2^n]) + y^n * s_R*x + let mut l_poly = util::VecPoly1::zero(nm); + let mut r_poly = util::VecPoly1::zero(nm); + + let mut i = 0; + let mut exp_z = z * z; + let mut exp_y = Scalar::one(); + + for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) { + let mut exp_2 = Scalar::one(); + + for j in 0..(*n_i) { + // `j` is guaranteed to be at most `u64::BITS` (a 6-bit number) and therefore, + // casting is lossless and right shift can be safely unwrapped + let a_L_j = Scalar::from(amount_i.checked_shr(j as u32).unwrap() & 1); + let a_R_j = a_L_j - Scalar::one(); + + l_poly.0[i] = a_L_j - z; + l_poly.1[i] = s_L[i]; + r_poly.0[i] = exp_y * (a_R_j + z) + exp_z * exp_2; + r_poly.1[i] = exp_y * s_R[i]; + + exp_y *= y; + exp_2 = exp_2 + exp_2; + + // `i` is capped by the sum of vectors in `bit_lengths` + i = i.checked_add(1).unwrap(); + } + exp_z *= z; + } + + // define t(x) = = t_0 + t_1*x + t_2*x + let t_poly = l_poly + .inner_product(&r_poly) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; + + // generate Pedersen commitment for the coefficients t_1 and t_2 + let (T_1, t_1_blinding) = Pedersen::new(t_poly.1); + let (T_2, t_2_blinding) = Pedersen::new(t_poly.2); + + let T_1 = T_1.get_point().compress(); + let T_2 = T_2.get_point().compress(); + + transcript.append_point(b"T_1", &T_1); + transcript.append_point(b"T_2", &T_2); + + // evaluate t(x) on challenge x and homomorphically compute the openings for + // z^2 * V_1 + z^3 * V_2 + ... + z^{m+1} * V_m + delta(y, z)*G + x*T_1 + x^2*T_2 + let x = transcript.challenge_scalar(b"x"); + + let mut agg_opening = Scalar::zero(); + let mut exp_z = z; + for opening in openings { + exp_z *= z; + agg_opening += exp_z * opening.get_scalar(); + } + + let t_blinding_poly = util::Poly2( + agg_opening, + *t_1_blinding.get_scalar(), + *t_2_blinding.get_scalar(), + ); + + let t_x = t_poly.eval(x); + let t_x_blinding = t_blinding_poly.eval(x); + + transcript.append_scalar(b"t_x", &t_x); + transcript.append_scalar(b"t_x_blinding", &t_x_blinding); + + // homomorphically compuate the openings for A + x*S + let e_blinding = a_blinding + s_blinding * x; + let l_vec = l_poly.eval(x); + let r_vec = r_poly.eval(x); + + transcript.append_scalar(b"e_blinding", &e_blinding); + + // compute the inner product argument on the commitment: + // P = + + *Q + let w = transcript.challenge_scalar(b"w"); + let Q = w * &(*G); + + let G_factors: Vec = iter::repeat(Scalar::one()).take(nm).collect(); + let H_factors: Vec = util::exp_iter(y.invert()).take(nm).collect(); + + // generate challenge `c` for consistency with the verifier's transcript + transcript.challenge_scalar(b"c"); + + let ipp_proof = InnerProductProof::new( + &Q, + &G_factors, + &H_factors, + bp_gens.G(nm).cloned().collect(), + bp_gens.H(nm).cloned().collect(), + l_vec, + r_vec, + transcript, + )?; + + Ok(RangeProof { + A, + S, + T_1, + T_2, + t_x, + t_x_blinding, + e_blinding, + ipp_proof, + }) + } + + #[allow(clippy::many_single_char_names)] + pub fn verify( + &self, + comms: Vec<&PedersenCommitment>, + bit_lengths: Vec, + transcript: &mut Transcript, + ) -> Result<(), RangeProofVerificationError> { + // commitments and bit-lengths must be same length vectors + if comms.len() != bit_lengths.len() { + return Err(RangeProofVerificationError::VectorLengthMismatch); + } + + let m = bit_lengths.len(); + let nm: usize = bit_lengths.iter().sum(); + let bp_gens = RangeProofGens::new(nm) + .map_err(|_| RangeProofVerificationError::MaximumGeneratorLengthExceeded)?; + + if !nm.is_power_of_two() { + return Err(RangeProofVerificationError::InvalidBitSize); + } + + // append proof data to transcript and derive appropriate challenge scalars + transcript.validate_and_append_point(b"A", &self.A)?; + transcript.validate_and_append_point(b"S", &self.S)?; + + let y = transcript.challenge_scalar(b"y"); + let z = transcript.challenge_scalar(b"z"); + + let zz = z * z; + let minus_z = -z; + + transcript.validate_and_append_point(b"T_1", &self.T_1)?; + transcript.validate_and_append_point(b"T_2", &self.T_2)?; + + let x = transcript.challenge_scalar(b"x"); + + transcript.append_scalar(b"t_x", &self.t_x); + transcript.append_scalar(b"t_x_blinding", &self.t_x_blinding); + transcript.append_scalar(b"e_blinding", &self.e_blinding); + + let w = transcript.challenge_scalar(b"w"); + let c = transcript.challenge_scalar(b"c"); // challenge value for batching multiscalar mul + + // verify inner product proof + let (x_sq, x_inv_sq, s) = self.ipp_proof.verification_scalars(nm, transcript)?; + let s_inv = s.iter().rev(); + + let a = self.ipp_proof.a; + let b = self.ipp_proof.b; + + // construct concat_z_and_2, an iterator of the values of + // z^0 * \vec(2)^n || z^1 * \vec(2)^n || ... || z^(m-1) * \vec(2)^n + let concat_z_and_2: Vec = util::exp_iter(z) + .zip(bit_lengths.iter()) + .flat_map(|(exp_z, n_i)| { + util::exp_iter(Scalar::from(2u64)) + .take(*n_i) + .map(move |exp_2| exp_2 * exp_z) + }) + .collect(); + + let gs = s.iter().map(|s_i| minus_z - a * s_i); + let hs = s_inv + .zip(util::exp_iter(y.invert())) + .zip(concat_z_and_2.iter()) + .map(|((s_i_inv, exp_y_inv), z_and_2)| z + exp_y_inv * (zz * z_and_2 - b * s_i_inv)); + + let basepoint_scalar = + w * (self.t_x - a * b) + c * (delta(&bit_lengths, &y, &z) - self.t_x); + let value_commitment_scalars = util::exp_iter(z).take(m).map(|z_exp| c * zz * z_exp); + + let mega_check = RistrettoPoint::optional_multiscalar_mul( + iter::once(Scalar::one()) + .chain(iter::once(x)) + .chain(iter::once(c * x)) + .chain(iter::once(c * x * x)) + .chain(iter::once(-self.e_blinding - c * self.t_x_blinding)) + .chain(iter::once(basepoint_scalar)) + .chain(x_sq.iter().cloned()) + .chain(x_inv_sq.iter().cloned()) + .chain(gs) + .chain(hs) + .chain(value_commitment_scalars), + iter::once(self.A.decompress()) + .chain(iter::once(self.S.decompress())) + .chain(iter::once(self.T_1.decompress())) + .chain(iter::once(self.T_2.decompress())) + .chain(iter::once(Some(*H))) + .chain(iter::once(Some(*G))) + .chain(self.ipp_proof.L_vec.iter().map(|L| L.decompress())) + .chain(self.ipp_proof.R_vec.iter().map(|R| R.decompress())) + .chain(bp_gens.G(nm).map(|&x| Some(x))) + .chain(bp_gens.H(nm).map(|&x| Some(x))) + .chain(comms.iter().map(|V| Some(*V.get_point()))), + ) + .ok_or(RangeProofVerificationError::MultiscalarMul)?; + + if mega_check.is_identity() { + Ok(()) + } else { + Err(RangeProofVerificationError::AlgebraicRelation) + } + } + + // Following the dalek rangeproof library signature for now. The exact method signature can be + // changed. + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::with_capacity(7 * 32 + self.ipp_proof.serialized_size()); + buf.extend_from_slice(self.A.as_bytes()); + buf.extend_from_slice(self.S.as_bytes()); + buf.extend_from_slice(self.T_1.as_bytes()); + buf.extend_from_slice(self.T_2.as_bytes()); + buf.extend_from_slice(self.t_x.as_bytes()); + buf.extend_from_slice(self.t_x_blinding.as_bytes()); + buf.extend_from_slice(self.e_blinding.as_bytes()); + buf.extend_from_slice(&self.ipp_proof.to_bytes()); + buf + } + + // Following the dalek rangeproof library signature for now. The exact method signature can be + // changed. + pub fn from_bytes(slice: &[u8]) -> Result { + if slice.len() % 32 != 0 { + return Err(RangeProofVerificationError::Deserialization); + } + if slice.len() < 7 * 32 { + return Err(RangeProofVerificationError::Deserialization); + } + + let A = CompressedRistretto(util::read32(&slice[0..])); + let S = CompressedRistretto(util::read32(&slice[32..])); + let T_1 = CompressedRistretto(util::read32(&slice[2 * 32..])); + let T_2 = CompressedRistretto(util::read32(&slice[3 * 32..])); + + let t_x = Scalar::from_canonical_bytes(util::read32(&slice[4 * 32..])) + .ok_or(RangeProofVerificationError::Deserialization)?; + let t_x_blinding = Scalar::from_canonical_bytes(util::read32(&slice[5 * 32..])) + .ok_or(RangeProofVerificationError::Deserialization)?; + let e_blinding = Scalar::from_canonical_bytes(util::read32(&slice[6 * 32..])) + .ok_or(RangeProofVerificationError::Deserialization)?; + + let ipp_proof = InnerProductProof::from_bytes(&slice[7 * 32..])?; + + Ok(RangeProof { + A, + S, + T_1, + T_2, + t_x, + t_x_blinding, + e_blinding, + ipp_proof, + }) + } +} + +/// Compute +/// \\[ +/// \delta(y,z) = (z - z^{2}) \langle \mathbf{1}, {\mathbf{y}}^{n \cdot m} \rangle - \sum_{j=0}^{m-1} z^{j+3} \cdot \langle \mathbf{1}, {\mathbf{2}}^{n \cdot m} \rangle +/// \\] +#[cfg(not(target_os = "solana"))] +fn delta(bit_lengths: &[usize], y: &Scalar, z: &Scalar) -> Scalar { + let nm: usize = bit_lengths.iter().sum(); + let sum_y = util::sum_of_powers(y, nm); + + let mut agg_delta = (z - z * z) * sum_y; + let mut exp_z = z * z * z; + for n_i in bit_lengths.iter() { + let sum_2 = util::sum_of_powers(&Scalar::from(2u64), *n_i); + agg_delta -= exp_z * sum_2; + exp_z *= z; + } + agg_delta +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_single_rangeproof() { + let (comm, open) = Pedersen::new(55_u64); + + let mut transcript_create = Transcript::new(b"Test"); + let mut transcript_verify = Transcript::new(b"Test"); + + let proof = + RangeProof::new(vec![55], vec![32], vec![&open], &mut transcript_create).unwrap(); + + assert!(proof + .verify(vec![&comm], vec![32], &mut transcript_verify) + .is_ok()); + } + + #[test] + fn test_aggregated_rangeproof() { + let (comm_1, open_1) = Pedersen::new(55_u64); + let (comm_2, open_2) = Pedersen::new(77_u64); + let (comm_3, open_3) = Pedersen::new(99_u64); + + let mut transcript_create = Transcript::new(b"Test"); + let mut transcript_verify = Transcript::new(b"Test"); + + let proof = RangeProof::new( + vec![55, 77, 99], + vec![64, 32, 32], + vec![&open_1, &open_2, &open_3], + &mut transcript_create, + ) + .unwrap(); + + assert!(proof + .verify( + vec![&comm_1, &comm_2, &comm_3], + vec![64, 32, 32], + &mut transcript_verify, + ) + .is_ok()); + } + + // TODO: write test for serialization/deserialization +} diff --git a/zk-sdk/src/range_proof/util.rs b/zk-sdk/src/range_proof/util.rs new file mode 100644 index 00000000000000..3d2b50ccb6e993 --- /dev/null +++ b/zk-sdk/src/range_proof/util.rs @@ -0,0 +1,139 @@ +/// Utility functions for Bulletproofs. +/// +/// The code is adapted from the `utility` module in the dalek bulletproof implementation +/// https://github.com/dalek-cryptography/bulletproofs. +use curve25519_dalek::scalar::Scalar; + +/// Represents a degree-1 vector polynomial \\(\mathbf{a} + \mathbf{b} \cdot x\\). +pub struct VecPoly1(pub Vec, pub Vec); + +impl VecPoly1 { + pub fn zero(n: usize) -> Self { + VecPoly1(vec![Scalar::zero(); n], vec![Scalar::zero(); n]) + } + + pub fn inner_product(&self, rhs: &VecPoly1) -> Option { + // Uses Karatsuba's method + let l = self; + let r = rhs; + + let t0 = inner_product(&l.0, &r.0)?; + let t2 = inner_product(&l.1, &r.1)?; + + let l0_plus_l1 = add_vec(&l.0, &l.1); + let r0_plus_r1 = add_vec(&r.0, &r.1); + + let t1 = inner_product(&l0_plus_l1, &r0_plus_r1)? - t0 - t2; + + Some(Poly2(t0, t1, t2)) + } + + pub fn eval(&self, x: Scalar) -> Vec { + let n = self.0.len(); + let mut out = vec![Scalar::zero(); n]; + #[allow(clippy::needless_range_loop)] + for i in 0..n { + out[i] = self.0[i] + self.1[i] * x; + } + out + } +} + +/// Represents a degree-2 scalar polynomial \\(a + b \cdot x + c \cdot x^2\\) +pub struct Poly2(pub Scalar, pub Scalar, pub Scalar); + +impl Poly2 { + pub fn eval(&self, x: Scalar) -> Scalar { + self.0 + x * (self.1 + x * self.2) + } +} + +/// Provides an iterator over the powers of a `Scalar`. +/// +/// This struct is created by the `exp_iter` function. +pub struct ScalarExp { + x: Scalar, + next_exp_x: Scalar, +} + +impl Iterator for ScalarExp { + type Item = Scalar; + + fn next(&mut self) -> Option { + let exp_x = self.next_exp_x; + self.next_exp_x *= self.x; + Some(exp_x) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), None) + } +} + +/// Return an iterator of the powers of `x`. +pub fn exp_iter(x: Scalar) -> ScalarExp { + let next_exp_x = Scalar::one(); + ScalarExp { x, next_exp_x } +} + +pub fn add_vec(a: &[Scalar], b: &[Scalar]) -> Vec { + if a.len() != b.len() { + // throw some error + //println!("lengths of vectors don't match for vector addition"); + } + let mut out = vec![Scalar::zero(); b.len()]; + for i in 0..a.len() { + out[i] = a[i] + b[i]; + } + out +} + +/// Given `data` with `len >= 32`, return the first 32 bytes. +pub fn read32(data: &[u8]) -> [u8; 32] { + let mut buf32 = [0u8; 32]; + buf32[..].copy_from_slice(&data[..32]); + buf32 +} + +/// Computes an inner product of two vectors +/// \\[ +/// {\langle {\mathbf{a}}, {\mathbf{b}} \rangle} = \sum\_{i=0}^{n-1} a\_i \cdot b\_i. +/// \\] +/// Errors if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal. +pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Option { + let mut out = Scalar::zero(); + if a.len() != b.len() { + return None; + } + for i in 0..a.len() { + out += a[i] * b[i]; + } + Some(out) +} + +/// Takes the sum of all the powers of `x`, up to `n` +/// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplications and additions. +/// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions. +/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2. +pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar { + if !n.is_power_of_two() { + return sum_of_powers_slow(x, n); + } + if n == 0 || n == 1 { + return Scalar::from(n as u64); + } + let mut m = n; + let mut result = Scalar::one() + x; + let mut factor = *x; + while m > 2 { + factor = factor * factor; + result = result + factor * result; + m /= 2; + } + result +} + +// takes the sum of all of the powers of x, up to n +fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar { + exp_iter(*x).take(n).sum() +}