diff --git a/crates/components/authdecode/authdecode-core/Cargo.toml b/crates/components/authdecode/authdecode-core/Cargo.toml
new file mode 100644
index 0000000000..9e87d79d68
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/Cargo.toml
@@ -0,0 +1,50 @@
+[package]
+name = "tlsn-authdecode-core"
+authors = ["TLSNotary Team"]
+description = "A 2PC protocol for authenticated decoding of encodings in zk"
+keywords = ["tls", "mpc", "2pc"]
+categories = ["cryptography"]
+license = "MIT OR Apache-2.0"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "authdecode_core"
+
+[features]
+default = []
+fixtures = ["dep:rand_chacha"]
+mock = ["dep:bincode", "dep:num", "dep:blake3"]
+tracing = ["dep:tracing"]
+
+[dependencies]
+poseidon-halo2 = { workspace = true }
+
+bincode = { workspace = true, optional = true }
+blake3 = { workspace = true, optional = true }
+cfg-if = "1"
+enum-try-as-inner = { workspace = true }
+ff = "0.13"
+getset = "0.1.2"
+group = "0.13"
+halo2_poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon-gadget", rev = "764a682" }
+halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v0.3.0", default-features = false }
+itybity = "0.2"
+lazy_static = "1.4"
+num = { version = "0.4.1", optional = true }
+opaque-debug = { workspace = true }
+rand = { workspace = true }
+rand_chacha = { workspace = true, optional = true }
+serde = { workspace = true, features = ["derive"] }
+thiserror = { workspace = true }
+tracing = { workspace = true, optional = true }
+
+[dev-dependencies]
+bincode = { workspace = true }
+blake3 = { workspace = true }
+criterion = { workspace = true }
+hex = { workspace = true }
+num = "0.4.1"
+rand_chacha = { workspace = true }
+rand_core = { workspace = true }
+rstest = { workspace = true }
diff --git a/crates/components/authdecode/authdecode-core/src/backend/halo2/circuit.rs b/crates/components/authdecode/authdecode-core/src/backend/halo2/circuit.rs
new file mode 100644
index 0000000000..a196ece54c
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/src/backend/halo2/circuit.rs
@@ -0,0 +1,606 @@
+use halo2_poseidon::poseidon::{primitives::ConstantLength, Hash, Pow5Chip, Pow5Config};
+use halo2_proofs::{
+    circuit::{AssignedCell, Layouter, Region, SimpleFloorPlanner, Value},
+    halo2curves::bn256::Fr as F,
+    plonk::{
+        Advice, Circuit, Column, ConstraintSystem, Constraints, Error, Expression, Instance,
+        Selector,
+    },
+    poly::Rotation,
+};
+use std::convert::TryInto;
+
+use crate::backend::halo2::{
+    poseidon::{configure_poseidon_rate_15, configure_poseidon_rate_2},
+    utils::{compose_bits, f_to_bits},
+};
+
+use poseidon_halo2::{Spec15, Spec2};
+
+// Rationale for the selection of constants.
+//
+// In order to optimize the proof generation time, the circuit should contain as few instance
+// columns and as few rows as possible.
+// The circuit has [super::CHUNK_SIZE] public inputs (deltas) which must be placed into instance
+// columns. It was empirically established that 58 rows and 64 instance columns provide the best
+// performance.
+//
+// Note that 58 usable rows is what we get when we set the circuit's K to 6 (halo2 reserves 6 rows
+// for internal purposes, so we get 2^6 - 6 usable rows).
+
+/// How many field elements the plaintext is packed into. Only [USABLE_BYTES] bytes of each field
+/// element are used.
+pub const FIELD_ELEMENTS: usize = 14;
+
+/// How many least significant bytes of each field element are used to pack the plaintext into.
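+/// A BN254 scalar is ~254 bits wide, so 31 bytes (248 bits) always fit into a field element
+/// without overflowing the modulus. One chunk thus holds [FIELD_ELEMENTS] * [USABLE_BYTES] =
+/// 14 * 31 = 434 bytes of plaintext (see [super::CHUNK_SIZE]).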
+pub const USABLE_BYTES: usize = 31;
+
+/// How many bits there are in one limb of a plaintext field element.
+//
+// Note that internally the bits of one plaintext field element are zero-padded on the left to a
+// total of 256 bits. Then the bits are split up into 4 limbs of [BITS_PER_LIMB] bits.
+pub const BITS_PER_LIMB: usize = 64;
+
+/// Bytesize of the salt used both in the plaintext commitment and the encoding sum commitment.
+pub const SALT_SIZE: usize = 16;
+
+#[derive(Clone, Debug)]
+/// The circuit configuration.
+pub struct CircuitConfig {
+    /// Columns containing plaintext bits.
+    bits: [Column<Advice>; BITS_PER_LIMB],
+    /// Scratch space used to calculate intermediate values.
+    scratch_space: [Column<Advice>; 5],
+    /// Expected dot product of a vector of deltas and a vector of a limb's bits.
+    dot_product: Column<Advice>,
+    /// Expected value when composing a 64-bit limb into a field element.
+    expected_composed_limbs: Column<Advice>,
+    /// The first and the second rows of this column are used to store the plaintext salt and the
+    /// encoding sum salt, respectively.
+    salt: Column<Advice>,
+
+    /// Columns of deltas, arranged such that each row of deltas corresponds to one limb of
+    /// plaintext.
+    deltas: [Column<Instance>; BITS_PER_LIMB],
+
+    /// Since halo2 does not allow constraining inputs in instance columns directly, we first need
+    /// to copy the inputs into this advice column.
+    advice_from_instance: Column<Advice>,
+
+    // SELECTORS.
+    // A selector activates a gate with a similar name, e.g. "selector_dot_product" activates the
+    // gate "dot_product" etc.
+    selector_dot_product: Selector,
+    selector_binary_check: Selector,
+    selector_compose_limb: [Selector; 4],
+    selector_sum: Selector,
+    selector_eight_bits_zero: Selector,
+
+    /// Config for rate-15 Poseidon.
+    poseidon_config_rate15: Pow5Config<F, 16, 15>,
+    /// Config for rate-2 Poseidon.
+    poseidon_config_rate2: Pow5Config<F, 3, 2>,
+
+    /// Contains the following public inputs in this order: (plaintext hash, encoding sum hash,
+    /// zero sum).
+    public_inputs: Column<Instance>,
+}
+
+#[derive(Clone, Debug)]
+/// The AuthDecode circuit.
+pub struct AuthDecodeCircuit {
+    /// The bits of the plaintext which was committed to. Each bit is represented as a field
+    /// element.
+    ///
+    /// The plaintext consists of [FIELD_ELEMENTS] field elements. Each field element is split
+    /// into 4 limbs of [BITS_PER_LIMB] bits. The high limb has the index 0.
+    pub plaintext: [[[F; BITS_PER_LIMB]; 4]; FIELD_ELEMENTS],
+    /// The salt used to create a plaintext commitment.
+    pub plaintext_salt: F,
+    /// The salt used to create an encoding sum commitment.
+ pub encoding_sum_salt: F, +} + +impl Circuit for AuthDecodeCircuit { + type Config = CircuitConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + plaintext: [[[F::default(); BITS_PER_LIMB]; 4]; FIELD_ELEMENTS], + plaintext_salt: F::default(), + encoding_sum_salt: F::default(), + } + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + // ADVICE COLUMNS + + let bits: [Column; BITS_PER_LIMB] = (0..BITS_PER_LIMB) + .map(|_| meta.advice_column()) + .collect::>() + .try_into() + .unwrap(); + + let dot_product = meta.advice_column(); + meta.enable_equality(dot_product); + + let expected_limbs = meta.advice_column(); + meta.enable_equality(expected_limbs); + + let salt = meta.advice_column(); + meta.enable_equality(salt); + + let scratch_space: [Column; 5] = (0..5) + .map(|_| { + let c = meta.advice_column(); + meta.enable_equality(c); + c + }) + .collect::>() + .try_into() + .unwrap(); + + let advice_from_instance = meta.advice_column(); + meta.enable_equality(advice_from_instance); + + // INSTANCE COLUMNS + + let deltas: [Column; BITS_PER_LIMB] = (0..BITS_PER_LIMB) + .map(|_| meta.instance_column()) + .collect::>() + .try_into() + .unwrap(); + + let public_inputs = meta.instance_column(); + meta.enable_equality(public_inputs); + + // SELECTORS + + let selector_dot_product = meta.selector(); + let selector_binary_check = meta.selector(); + let selector_compose: [Selector; 4] = (0..4) + .map(|_| meta.selector()) + .collect::>() + .try_into() + .unwrap(); + let selector_sum = meta.selector(); + let selector_eight_bits_zero = meta.selector(); + + // POSEIDON + + let poseidon_config_rate15 = configure_poseidon_rate_15::(15, meta); + let poseidon_config_rate2 = configure_poseidon_rate_2::(2, meta); + // We need to have one column for global constants which the Poseidon chip requires. + let global_constants = meta.fixed_column(); + meta.enable_constant(global_constants); + + // Put everything initialized above into a config. + let cfg = CircuitConfig { + bits, + scratch_space, + dot_product, + expected_composed_limbs: expected_limbs, + salt, + advice_from_instance, + + deltas, + + selector_dot_product, + selector_compose_limb: selector_compose, + selector_binary_check, + selector_sum, + selector_eight_bits_zero, + + poseidon_config_rate15, + poseidon_config_rate2, + + public_inputs, + }; + + // MISC + + // Build `Expression`s containing powers of 2 from the 0th to the 255th power. + let mut pow_2_x: Vec = Vec::with_capacity(256); + let two = F::one() + F::one(); + // Push 2^0. + pow_2_x.push(F::one()); + + for n in 1..256 { + // Push 2^n. + pow_2_x.push(pow_2_x[n - 1] * two); + } + + let pow_2_x = pow_2_x + .into_iter() + .map(Expression::Constant) + .collect::>(); + + // GATES + + // Computes the dot product of a vector of deltas and a vector of a limb's bits and + // constrains it to match the expected dot product. + meta.create_gate("dot_product", |meta| { + let mut product = Expression::Constant(F::zero()); + + for i in 0..BITS_PER_LIMB { + let delta = meta.query_instance(cfg.deltas[i], Rotation::cur()); + let bit = meta.query_advice(cfg.bits[i], Rotation::cur()); + product = product + delta * bit; + } + + // Constrain to match the expected dot product. + let expected = meta.query_advice(cfg.dot_product, Rotation::cur()); + let sel = meta.query_selector(cfg.selector_dot_product); + vec![sel * (product - expected)] + }); + + // Constrains each bit of a limb to be binary. 
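+        // A value b is a bit iff b * (b - 1) == 0: the product vanishes only for b = 0 or
+        // b = 1 (e.g. b = 2 gives 2 * 1 = 2 != 0 and fails the check).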
+ meta.create_gate("binary_check", |meta| { + // Create an `Expression` for each bit. + let expressions: [Expression; BITS_PER_LIMB] = (0..BITS_PER_LIMB) + .map(|i| { + let bit = meta.query_advice(cfg.bits[i], Rotation::cur()); + bit.clone() * bit.clone() - bit + }) + .collect::>() + .try_into() + .unwrap(); + let sel = meta.query_selector(cfg.selector_binary_check); + + // Constrain all expressions to be equal to 0. + Constraints::with_selector(sel, expressions) + }); + + // Create 4 gates for each of the 4 limbs of the plaintext bits, starting from the high limb. + for idx in 0..4 { + // Compose the bits of a limb into a field element, left-shifting if necessary and + // constrain the result to match the expected value. + meta.create_gate("compose_limb", |meta| { + let mut sum_total = Expression::Constant(F::zero()); + + for i in 0..BITS_PER_LIMB { + // The first bit is the highest bit. It is multiplied by the + // highest power of 2 for that limb. + let bit = meta.query_advice(cfg.bits[i], Rotation::cur()); + sum_total = sum_total + bit * pow_2_x[255 - (BITS_PER_LIMB * idx) - i].clone(); + } + + // Constrain to match the expected limb value. + let expected = meta.query_advice(cfg.expected_composed_limbs, Rotation::cur()); + let sel = meta.query_selector(cfg.selector_compose_limb[idx]); + vec![sel * (sum_total - expected)] + }); + } + + // Sums 4 cells in the scratch space and constrains the sum to equal the expected value. + meta.create_gate("sum", |meta| { + let mut sum = Expression::Constant(F::zero()); + + for i in 0..4 { + let value = meta.query_advice(cfg.scratch_space[i], Rotation::cur()); + sum = sum + value; + } + + // Constrain to match the expected sum. + let expected = meta.query_advice(cfg.scratch_space[4], Rotation::cur()); + let sel = meta.query_selector(cfg.selector_sum); + vec![sel * (sum - expected)] + }); + + // Constrains 8 most significant bits of a limb to be zero. + meta.create_gate("eight_bits_zero", |meta| { + let expressions: [Expression; 8] = (0..8) + .map(|i| meta.query_advice(cfg.bits[i], Rotation::cur())) + .collect::>() + .try_into() + .unwrap(); + + let sel = meta.query_selector(cfg.selector_eight_bits_zero); + + // Constrain all expressions to be equal to 0. 
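+            // (`Constraints::with_selector` multiplies each expression by the selector, so the
+            // checks are enforced only on rows where `selector_eight_bits_zero` is enabled.)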
+ Constraints::with_selector(sel, expressions) + }); + + cfg + } + + fn synthesize(&self, cfg: Self::Config, mut layouter: impl Layouter) -> Result<(), Error> { + let ( + expected_plaintext_hash, + expected_encoding_sum_hash, + zero_sum, + plaintext_salt, + encoding_sum_salt, + ) = layouter.assign_region( + || "assign advice from instance", + |mut region| { + let expected_plaintext_hash = region.assign_advice_from_instance( + || "assign plaintext hash", + cfg.public_inputs, + 0, + cfg.advice_from_instance, + 0, + )?; + + let expected_encoding_sum_hash = region.assign_advice_from_instance( + || "assign encoding sum hash", + cfg.public_inputs, + 1, + cfg.advice_from_instance, + 1, + )?; + + let zero_sum = region.assign_advice_from_instance( + || "assign zero sum", + cfg.public_inputs, + 2, + cfg.advice_from_instance, + 2, + )?; + + let plaintext_salt = region.assign_advice( + || "assign plaintext salt", + cfg.salt, + 0, + || Value::known(self.plaintext_salt), + )?; + + let encoding_sum_salt = region.assign_advice( + || "assign encoding sum salt", + cfg.salt, + 1, + || Value::known(self.encoding_sum_salt), + )?; + + Ok(( + expected_plaintext_hash, + expected_encoding_sum_hash, + zero_sum, + plaintext_salt, + encoding_sum_salt, + )) + }, + )?; + + let (mut plaintext, encoding_sum) = layouter.assign_region( + || "compose plaintext and compute encoding sum", + |mut region| { + // Plaintext field elements composed from bits. + let mut plaintext = Vec::new(); + + // A dot product of one field element's bits with the corresponding deltas. + let mut dot_products = Vec::new(); + + // Row offset of the scratch space. + let mut offset = 0; + + // Process 4 limbs of one field element of the plaintext at a time. + for (field_element_idx, limbs) in self.plaintext.iter().enumerate() { + // Expected values of limbs composed from bits. + let mut expected_limbs = Vec::with_capacity(4); + + // Expected dot product for each vector of limb's bits and a corresponding vector + // of deltas. + let mut expected_dot_products = Vec::with_capacity(4); + + // Process one limb's bits at a time. + for (limb_idx, limb_bits) in limbs.iter().enumerate() { + // The index of the row where the bits and deltas are located. + let row_idx = field_element_idx * 4 + limb_idx; + + // Assign bits of a limb to the same row. + for (i, bit) in limb_bits.iter().enumerate() { + region.assign_advice( + || "assign limb bits", + cfg.bits[i], + row_idx, + || Value::known(*bit), + )?; + } + // Constrain the whole row of bits to be binary. + cfg.selector_binary_check.enable(&mut region, row_idx)?; + + if limb_idx == 0 { + // Constrain the high limb's MSBs to be zero. + cfg.selector_eight_bits_zero.enable(&mut region, row_idx)?; + } + + let expected_limb = compose_bits(limb_bits, limb_idx); + + // Assign the expected composed limb. + expected_limbs.push(region.assign_advice( + || "assign the expected composed limb", + cfg.expected_composed_limbs, + row_idx, + || Value::known(expected_limb), + )?); + + // Constrain the expected limb to match the value which the gate composes. + cfg.selector_compose_limb[limb_idx].enable(&mut region, row_idx)?; + + // Compute and assign the expected dot product for this row. 
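+                        // I.e. the witness value of sum_i(delta_i * bit_i) over the row; only
+                        // positions where bit_i == 1 contribute, and the `dot_product` gate
+                        // re-computes the same sum from the assigned cells and equates the two.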
+ let mut expected_dot_product = Value::known(F::zero()); + for (i, bit) in limb_bits.iter().enumerate() { + let delta = region.instance_value(cfg.deltas[i], row_idx)?; + expected_dot_product = expected_dot_product + delta * Value::known(bit); + } + + expected_dot_products.push(region.assign_advice( + || "assign expected dot product", + cfg.dot_product, + row_idx, + || expected_dot_product, + )?); + + // Constrain the expected dot product to match the value which the gate + // computes. + cfg.selector_dot_product.enable(&mut region, row_idx)?; + } + + // Sum 4 limbs to get the plaintext field element. + plaintext.push(self.sum(&expected_limbs, &mut region, &cfg, &mut offset)?); + + // Sum 4 sub dot products to get the dot product of one field element's bits with + // the corresponding deltas. + dot_products.push(self.sum( + &expected_dot_products, + &mut region, + &cfg, + &mut offset, + )?); + } + + // Compute the sub sums for each chunk of 4 sub dot products. We will have 4 sub sums in total. + // XXX: This is hardcoded to 4 sub sums which is good enough if we have anywhere from 13 + // to 16 field elements of plaintext. + let four_sums = dot_products + .chunks(4) + .map(|chunk| self.sum(chunk, &mut region, &cfg, &mut offset)) + .collect::, Error>>()?; + + // Compute the final dot product. + let dot_product = self.sum(&four_sums, &mut region, &cfg, &mut offset)?; + + // Add zero sum and the final dot product to get encoding sum. + let encoding_sum = self.sum( + &[dot_product, zero_sum.clone()], + &mut region, + &cfg, + &mut offset, + )?; + + Ok((plaintext, encoding_sum)) + }, + )?; + + // Hash the salted encoding sum and constrain the digest to match the expected value. + let chip = Pow5Chip::construct(cfg.poseidon_config_rate2.clone()); + let hasher = Hash::, 3, 2>::init( + chip, + layouter.namespace(|| "init spec2 poseidon"), + )?; + + let output = hasher.hash( + layouter.namespace(|| "hash spec2 poseidon"), + vec![encoding_sum, encoding_sum_salt.clone()] + .try_into() + .unwrap(), + )?; + + layouter.assign_region( + || "constrain encoding sum digest", + |mut region| { + region.constrain_equal(output.cell(), expected_encoding_sum_hash.cell())?; + Ok(()) + }, + )?; + + // Hash the salted plaintext and constrain the digest to match the expected value. + plaintext.push(plaintext_salt.clone()); + + let chip = Pow5Chip::construct(cfg.poseidon_config_rate15.clone()); + + let hasher = Hash::, 16, 15>::init( + chip, + layouter.namespace(|| "init spec15 poseidon"), + )?; + // unwrap() is safe since we use exactly 15 field elements of plaintext. + let output = hasher.hash( + layouter.namespace(|| "hash spec15 poseidon"), + plaintext.try_into().unwrap(), + )?; + + layouter.assign_region( + || "constrain plaintext digest", + |mut region| { + region.constrain_equal(output.cell(), expected_plaintext_hash.cell())?; + Ok(()) + }, + )?; + + Ok(()) + } +} + +impl AuthDecodeCircuit { + /// Creates a new AuthDecode circuit. + pub fn new(plaintext: [F; FIELD_ELEMENTS], plaintext_salt: F, encoding_sum_salt: F) -> Self { + // Split each field element into 4 BITS_PER_LIMB-bit limbs. The high limb has index 0. + Self { + plaintext: plaintext + .into_iter() + .map(|f| { + f_to_bits(&f) + .into_iter() + // Convert each bit into a field element. 
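+                    // (`F::from` maps false -> F::zero() and true -> F::one().)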
+                    .map(F::from)
+                    .collect::<Vec<_>>()
+                    .chunks(BITS_PER_LIMB)
+                    .map(|chunk| chunk.try_into().unwrap())
+                    .collect::<Vec<_>>()
+                    .try_into()
+                    .unwrap()
+            })
+            .collect::<Vec<_>>()
+            .try_into()
+            .unwrap(),
+            plaintext_salt,
+            encoding_sum_salt,
+        }
+    }
+
+    /// Calculates the sum of values in `cells` and returns a cell constrained to equal the sum.
+    ///
+    /// # Arguments
+    /// * `cells` - The cells containing the values to be summed.
+    /// * `region` - The halo2 region.
+    /// * `config` - The circuit config.
+    /// * `row_offset` - The offset of the row in the scratch space on which the calculation
+    ///   will be performed.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the number of `cells` is less than 2 or greater than 4.
+    fn sum(
+        &self,
+        cells: &[AssignedCell<F, F>],
+        region: &mut Region<F>,
+        config: &CircuitConfig,
+        row_offset: &mut usize,
+    ) -> Result<AssignedCell<F, F>, Error> {
+        assert!(cells.len() <= 4 && cells.len() >= 2);
+
+        let mut sum = Value::known(F::zero());
+        // Copy the cells onto the same row and compute their sum.
+        for (i, cell) in cells.iter().enumerate() {
+            cell.copy_advice(
+                || "copying summands",
+                region,
+                config.scratch_space[i],
+                *row_offset,
+            )?;
+            sum = sum + cell.value();
+        }
+        // If there were fewer than 4 cells to sum, constrain the unused cells to be 0.
+        for i in cells.len()..4 {
+            region.assign_advice_from_constant(
+                || "assigning zero values",
+                config.scratch_space[i],
+                *row_offset,
+                F::zero(),
+            )?;
+        }
+
+        let assigned_sum = region.assign_advice(
+            || "assigning the sum",
+            config.scratch_space[4],
+            *row_offset,
+            || sum,
+        )?;
+
+        config.selector_sum.enable(region, *row_offset)?;
+
+        *row_offset += 1;
+
+        Ok(assigned_sum)
+    }
+}
diff --git a/crates/components/authdecode/authdecode-core/src/backend/halo2/fixtures.rs b/crates/components/authdecode/authdecode-core/src/backend/halo2/fixtures.rs
new file mode 100644
index 0000000000..2e733986e0
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/src/backend/halo2/fixtures.rs
@@ -0,0 +1,118 @@
+use crate::{
+    backend::{
+        halo2::{
+            circuit::USABLE_BYTES,
+            prepare_instance,
+            prover::{Prover, _prepare_circuit},
+            verifier::Verifier,
+            Bn256F, PARAMS,
+        },
+        traits::{Field, ProverBackend, VerifierBackend},
+    },
+    prover::ProverInput,
+    Proof,
+};
+
+use halo2_proofs::{
+    dev::MockProver,
+    halo2curves::bn256::Bn256,
+    poly::{commitment::Params, kzg::commitment::ParamsKZG},
+};
+
+use serde::{de::DeserializeOwned, Serialize};
+use std::{
+    any::Any,
+    ops::{Add, Sub},
+};
+
+/// Returns a pair of backends which use halo2's MockProver to prove and verify.
+pub fn backend_pair_mock() -> (ProverBackendWrapper, VerifierBackendWrapper<Bn256F>) {
+    let pair = backend_pair();
+    (
+        ProverBackendWrapper {
+            prover: Box::new(pair.0),
+        },
+        VerifierBackendWrapper {
+            verifier: Box::new(pair.1),
+        },
+    )
+}
+
+/// Returns a pair of zk backends which use halo2.
+pub fn backend_pair() -> (Prover, Verifier) {
+    (Prover::new(), Verifier::new())
+}
+
+/// Returns the K parameter.
+pub fn k() -> u32 {
+    ParamsKZG::<Bn256>::k(&PARAMS)
+}
+
+// A wrapper of the prover backend which uses MockProver to prove and verify.
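+// MockProver synthesizes the circuit in memory and checks every constraint directly against the
+// witness, reporting exactly which gate failed; this is much faster than creating a real proof,
+// which makes it well suited for tests.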
+pub struct ProverBackendWrapper { + prover: Box>, +} + +impl ProverBackend for ProverBackendWrapper { + fn chunk_size(&self) -> usize { + self.prover.chunk_size() + } + + fn commit_encoding_sum(&self, encoding_sum: Bn256F) -> (Bn256F, Bn256F) { + self.prover.commit_encoding_sum(encoding_sum) + } + + fn commit_plaintext(&self, plaintext: Vec) -> (Bn256F, Bn256F) { + self.prover.commit_plaintext(plaintext) + } + + fn commit_plaintext_with_salt(&self, plaintext: Vec, salt: Bn256F) -> Bn256F { + self.prover.commit_plaintext_with_salt(plaintext, salt) + } + + fn prove( + &self, + input: Vec>, + ) -> Result, crate::prover::ProverError> { + _ = input + .into_iter() + .map(|input| { + let instance_columns = prepare_instance(input.public(), USABLE_BYTES); + let circuit = _prepare_circuit(input.private(), USABLE_BYTES); + + let prover = MockProver::run(k(), &circuit, instance_columns).unwrap(); + assert!(prover.verify().is_ok()); + }) + .collect::>(); + + // Return a dummy proof. + Ok(vec![Proof::new(&[0u8])]) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +// A wrapper of the verifier backend. +pub struct VerifierBackendWrapper { + verifier: Box>, +} + +impl VerifierBackend for VerifierBackendWrapper +where + F: Field + Add + Sub + Serialize + DeserializeOwned + Clone, +{ + fn chunk_size(&self) -> usize { + self.verifier.chunk_size() + } + + fn verify( + &self, + _inputs: Vec>, + _proofs: Vec, + ) -> Result<(), crate::verifier::VerifierError> { + // The proof has already been verified with MockProver::verify(). + Ok(()) + } +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/halo2/kzg_bn254_6.srs b/crates/components/authdecode/authdecode-core/src/backend/halo2/kzg_bn254_6.srs new file mode 100644 index 0000000000..f5f3699947 Binary files /dev/null and b/crates/components/authdecode/authdecode-core/src/backend/halo2/kzg_bn254_6.srs differ diff --git a/crates/components/authdecode/authdecode-core/src/backend/halo2/mod.rs b/crates/components/authdecode/authdecode-core/src/backend/halo2/mod.rs new file mode 100644 index 0000000000..515d9be3bd --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/halo2/mod.rs @@ -0,0 +1,155 @@ +use crate::{ + backend::{ + halo2::{ + circuit::{BITS_PER_LIMB, FIELD_ELEMENTS}, + utils::{bytes_be_to_f, slice_to_columns}, + }, + traits::Field, + }, + PublicInput, +}; + +use halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr}, + poly::kzg::commitment::ParamsKZG, +}; + +use lazy_static::lazy_static; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::ops::{Add, Sub}; + +mod circuit; +pub mod onetimesetup; +pub mod poseidon; +pub mod prover; +mod utils; +pub mod verifier; + +#[cfg(any(test, feature = "fixtures"))] +pub mod fixtures; + +lazy_static! { + static ref PARAMS: ParamsKZG = onetimesetup::params(); +} + +/// The bytesize of one chunk of plaintext. +pub const CHUNK_SIZE: usize = circuit::FIELD_ELEMENTS * circuit::USABLE_BYTES; + +/// A field element of the Bn256 curve. +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Bn256F { + #[serde(serialize_with = "fr_serialize", deserialize_with = "fr_deserialize")] + pub inner: Fr, +} +impl Bn256F { + /// Creates a new Bn256 field element. + pub fn new(inner: Fr) -> Self { + Self { inner } + } +} + +impl Field for Bn256F { + fn from_bytes_be(bytes: Vec) -> Self { + Self { + inner: bytes_be_to_f(bytes), + } + } + + fn to_bytes_be(self) -> Vec { + let mut le = self.inner.to_bytes(); + // Reverse from little-endian to big-endian. 
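+        // (`Fr::to_bytes` returns the canonical little-endian encoding, e.g. 258 -> [0x02, 0x01,
+        // 0x00, ...]; reversing yields the big-endian form used by the `Field` trait.)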
+ le.reverse(); + le.to_vec() + } + + fn zero() -> Self { + Self { inner: Fr::zero() } + } +} + +impl Add for Bn256F { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + inner: self.inner + rhs.inner, + } + } +} + +impl Sub for Bn256F { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self { + inner: self.inner - rhs.inner, + } + } +} + +impl From for Bn256F { + fn from(value: Fr) -> Self { + Bn256F::new(value) + } +} + +#[allow(clippy::from_over_into)] +impl Into for &Bn256F { + fn into(self) -> Fr { + self.inner + } +} + +// Serializes the `Fr` type into bytes. +fn fr_serialize(fr: &Fr, serializer: S) -> Result +where + S: Serializer, +{ + serializer.serialize_bytes(&fr.to_bytes()) +} + +// Deserializes the `Fr` type from bytes. +fn fr_deserialize<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let bytes: [u8; 32] = Vec::deserialize(deserializer)? + .try_into() + .map_err(|_| serde::de::Error::custom("the amount of bytes is not 32"))?; + + let res = Fr::from_bytes(&bytes); + if res.is_none().into() { + return Err(serde::de::Error::custom( + "the bytes are not a valid field element", + )); + } + Ok(res.unwrap()) +} + +/// Prepares instance columns. +#[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all))] +fn prepare_instance(input: &PublicInput, usable_bytes: usize) -> Vec> { + let deltas = input + .deltas + .iter() + .map(|f: &Bn256F| f.inner) + .collect::>(); + + // Arrange deltas in instance columns. + let mut instance_columns = slice_to_columns( + &deltas, + usable_bytes * 8, + BITS_PER_LIMB * 4, + FIELD_ELEMENTS * 4, + BITS_PER_LIMB, + ); + + // Add another column with public inputs. + instance_columns.push(vec![ + input.plaintext_hash.inner, + input.encoding_sum_hash.inner, + input.zero_sum.inner, + ]); + + instance_columns +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/halo2/onetimesetup.rs b/crates/components/authdecode/authdecode-core/src/backend/halo2/onetimesetup.rs new file mode 100644 index 0000000000..13b9ca6da0 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/halo2/onetimesetup.rs @@ -0,0 +1,45 @@ +// A one-time setup generates the proving key and the verification key. The keys are deterministic, so +// they can be cached and re-used for all future proof generation and verification. + +use crate::backend::halo2::{ + circuit::{AuthDecodeCircuit, FIELD_ELEMENTS}, + PARAMS, +}; + +use halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr as F, G1Affine}, + plonk, + plonk::{ProvingKey, VerifyingKey}, + poly::{commitment::Params, kzg::commitment::ParamsKZG}, +}; + +/// Returns the verification key for the AuthDecode circuit. +pub fn verification_key() -> VerifyingKey { + // It is safe to `unwrap` since we are inputting deterministic params and circuit. + plonk::keygen_vk(&PARAMS.clone(), &circuit_instance()).unwrap() +} + +/// Returns the proving key for the AuthDecode circuit. +pub fn proving_key() -> ProvingKey { + // It is safe to `unwrap` since we are inputting deterministic params and circuit. + plonk::keygen_pk(&PARAMS.clone(), verification_key(), &circuit_instance()).unwrap() +} + +/// Returns the parameters used to generate the proving and the verification key. 
+pub(crate) fn params() -> ParamsKZG { + // Parameters were taken from Axiom's trusted setup described here: + // https://docs.axiom.xyz/docs/transparency-and-security/kzg-trusted-setup , + // located at https://axiom-crypto.s3.amazonaws.com/challenge_0085/kzg_bn254_15.srs + // + // They were downsized by calling `ParamsKZG::downsize(6)` with v0.3.0 of + // https://github.com/privacy-scaling-explorations/halo2 + + let bytes = include_bytes!("kzg_bn254_6.srs"); + ParamsKZG::read(&mut bytes.as_slice()).unwrap() +} + +/// Returns an instance of the AuthDecode circuit. +fn circuit_instance() -> AuthDecodeCircuit { + // We need an instance of the circuit, the exact inputs don't matter. + AuthDecodeCircuit::new([F::default(); FIELD_ELEMENTS], F::default(), F::default()) +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/halo2/poseidon.rs b/crates/components/authdecode/authdecode-core/src/backend/halo2/poseidon.rs new file mode 100644 index 0000000000..45662b0209 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/halo2/poseidon.rs @@ -0,0 +1,69 @@ +use halo2_poseidon::poseidon::{primitives::Spec, Pow5Chip, Pow5Config}; +use halo2_proofs::{halo2curves::bn256::Fr as F, plonk::ConstraintSystem}; + +/// Configures the in-circuit Poseidon for rate 15 and returns the config. +// Patterned after https://github.com/privacy-scaling-explorations/poseidon-gadget/blob/764a682ee448bfbde0cc92a04d241fe738ba2d14/src/poseidon/pow5.rs#L621 +pub fn configure_poseidon_rate_15>( + rate: usize, + meta: &mut ConstraintSystem, +) -> Pow5Config { + let width = rate + 1; + let state = (0..width).map(|_| meta.advice_column()).collect::>(); + let partial_sbox = meta.advice_column(); + + let rc_a = (0..width).map(|_| meta.fixed_column()).collect::>(); + let rc_b = (0..width).map(|_| meta.fixed_column()).collect::>(); + + Pow5Chip::configure::( + meta, + state.try_into().unwrap(), + partial_sbox, + rc_a.try_into().unwrap(), + rc_b.try_into().unwrap(), + ) +} + +/// Configures the in-circuit Poseidon for rate 1 and returns the config +// Patterned after https://github.com/privacy-scaling-explorations/poseidon-gadget/blob/764a682ee448bfbde0cc92a04d241fe738ba2d14/src/poseidon/pow5.rs#L621 +#[allow(dead_code)] +pub fn configure_poseidon_rate_1>( + rate: usize, + meta: &mut ConstraintSystem, +) -> Pow5Config { + let width = rate + 1; + let state = (0..width).map(|_| meta.advice_column()).collect::>(); + let partial_sbox = meta.advice_column(); + + let rc_a = (0..width).map(|_| meta.fixed_column()).collect::>(); + let rc_b = (0..width).map(|_| meta.fixed_column()).collect::>(); + + Pow5Chip::configure::( + meta, + state.try_into().unwrap(), + partial_sbox, + rc_a.try_into().unwrap(), + rc_b.try_into().unwrap(), + ) +} + +/// Configures the in-circuit Poseidon for rate 2 and returns the config +// Patterned after https://github.com/privacy-scaling-explorations/poseidon-gadget/blob/764a682ee448bfbde0cc92a04d241fe738ba2d14/src/poseidon/pow5.rs#L621 +pub fn configure_poseidon_rate_2>( + rate: usize, + meta: &mut ConstraintSystem, +) -> Pow5Config { + let width = rate + 1; + let state = (0..width).map(|_| meta.advice_column()).collect::>(); + let partial_sbox = meta.advice_column(); + + let rc_a = (0..width).map(|_| meta.fixed_column()).collect::>(); + let rc_b = (0..width).map(|_| meta.fixed_column()).collect::>(); + + Pow5Chip::configure::( + meta, + state.try_into().unwrap(), + partial_sbox, + rc_a.try_into().unwrap(), + rc_b.try_into().unwrap(), + ) +} diff --git 
a/crates/components/authdecode/authdecode-core/src/backend/halo2/prover.rs b/crates/components/authdecode/authdecode-core/src/backend/halo2/prover.rs new file mode 100644 index 0000000000..5f0760f3e9 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/halo2/prover.rs @@ -0,0 +1,392 @@ +use crate::{ + backend::{ + halo2::{ + circuit::{AuthDecodeCircuit, FIELD_ELEMENTS, SALT_SIZE, USABLE_BYTES}, + onetimesetup::proving_key, + utils::bytes_be_to_f, + Bn256F, CHUNK_SIZE, PARAMS, + }, + traits::{Field, ProverBackend as Backend}, + }, + prover::{PrivateInput, ProverError, ProverInput}, + Proof, +}; + +use halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr as F, G1Affine}, + plonk, + plonk::ProvingKey, + poly::kzg::{commitment::KZGCommitmentScheme, multiopen::ProverGWC}, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, +}; + +use poseidon_halo2::hash; + +use rand::{thread_rng, Rng}; + +#[cfg(any(test, feature = "fixtures"))] +use std::any::Any; + +#[cfg(feature = "tracing")] +use tracing::{debug, debug_span, instrument, Instrument}; + +use super::prepare_instance; + +/// The Prover of the AuthDecode circuit. +#[derive(Clone)] +pub struct Prover { + /// The proving key. + proving_key: ProvingKey, +} + +impl Backend for Prover { + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all))] + fn commit_plaintext(&self, plaintext: Vec) -> (Bn256F, Bn256F) { + // Generate a random salt and add it to the plaintext. + let mut rng = thread_rng(); + let salt = core::iter::repeat_with(|| rng.gen::()) + .take(SALT_SIZE) + .collect::>(); + let salt = Bn256F::from_bytes_be(salt); + + ( + self.commit_plaintext_with_salt(plaintext, salt.clone()), + salt, + ) + } + + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all))] + fn commit_plaintext_with_salt(&self, plaintext: Vec, salt: Bn256F) -> Bn256F { + assert!(plaintext.len() <= self.chunk_size()); + + // Split up the plaintext bytes into field elements. + let mut plaintext: Vec = plaintext + .chunks(self.usable_bytes()) + .map(|bytes| Bn256F::from_bytes_be(bytes.to_vec())) + .collect::>(); + // Zero-pad the total count of field elements if needed. + plaintext.extend(vec![Bn256F::zero(); FIELD_ELEMENTS - plaintext.len()]); + + plaintext.push(salt); + + hash_internal(&plaintext) + } + + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all))] + fn commit_encoding_sum(&self, encoding_sum: Bn256F) -> (Bn256F, Bn256F) { + // Generate a random salt. + let mut rng = thread_rng(); + let salt = core::iter::repeat_with(|| rng.gen::()) + .take(SALT_SIZE) + .collect::>(); + let salt = Bn256F::from_bytes_be(salt); + + // XXX: we could pack the sum and the salt into a single field element at the cost of performing + // an additional range check in the circuit, but the gains would be negligible. + (hash_internal(&[encoding_sum, salt.clone()]), salt) + } + + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] + fn prove(&self, input: Vec>) -> Result, ProverError> { + // XXX: using the default strategy of proving one chunk of plaintext with one proof. + // There are considerable gains to be had when proving multiple chunks with one proof. 
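+        // (halo2's `plonk::create_proof` accepts a slice of circuits sharing one proving key, so
+        // a batched strategy could amortize the fixed proving cost across chunks.)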
+ + let proofs = input + .into_iter() + .map(|input| { + let instance_columns = prepare_instance(input.public(), self.usable_bytes()); + let circuit = prepare_circuit(input.private(), self.usable_bytes()); + + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + + plonk::create_proof::< + KZGCommitmentScheme, + ProverGWC<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255<_>>, + _, + >( + &PARAMS, + &self.proving_key, + &[circuit.clone()], + &[&instance_columns + .iter() + .map(|col| col.as_slice()) + .collect::>()], + &mut thread_rng(), + &mut transcript, + ) + .map_err(|e| ProverError::ProvingBackendError(e.to_string()))?; + + Ok(Proof::new(&transcript.finalize())) + }) + .collect::, ProverError>>()?; + + Ok(proofs) + } + + fn chunk_size(&self) -> usize { + CHUNK_SIZE + } + + #[cfg(any(test, feature = "fixtures"))] + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Default for Prover { + fn default() -> Self { + Self::new() + } +} + +impl Prover { + /// Generates a proving key and creates a new prover. + // + // To prevent the latency caused by the generation of a proving key, consider caching + // the proving key and use `new_with_key` instead. + pub fn new() -> Self { + Self { + proving_key: proving_key(), + } + } + + /// Creates a new prover with the provided proving key. + pub fn new_with_key(proving_key: ProvingKey) -> Self { + Self { proving_key } + } + + /// How many least significant bytes of a field element are used to pack the plaintext into. + fn usable_bytes(&self) -> usize { + USABLE_BYTES + } +} + +/// Prepares an instance of the circuit. +#[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all))] +fn prepare_circuit(input: &PrivateInput, usable_bytes: usize) -> AuthDecodeCircuit { + // Split up the plaintext into field elements. + let mut plaintext: Vec = input + .plaintext() + .chunks(usable_bytes) + .map(|bytes| bytes_be_to_f(bytes.to_vec())) + .collect::>(); + // Zero-pad the total count of field elements if needed. + plaintext.extend(vec![F::zero(); FIELD_ELEMENTS - plaintext.len()]); + + AuthDecodeCircuit::new( + plaintext.try_into().unwrap(), + input.plaintext_salt().inner, + input.encoding_sum_salt().inner, + ) +} + +/// Hashes `inputs` with Poseidon and returns the digest. +fn hash_internal(inputs: &[Bn256F]) -> Bn256F { + hash(&inputs.iter().map(|f| f.into()).collect::>()).into() +} + +#[cfg(any(test, feature = "fixtures"))] +/// Wraps `prepare_circuit` to expose it for fixtures. +pub fn _prepare_circuit(input: &PrivateInput, usable_bytes: usize) -> AuthDecodeCircuit { + prepare_circuit(input, usable_bytes) +} + +#[cfg(test)] +// Whether the `test_binary_check_fail` test is running. +pub static mut TEST_BINARY_CHECK_FAIL_IS_RUNNING: bool = false; + +#[cfg(test)] +mod tests { + use crate::{ + backend::halo2::{verifier::Verifier, BITS_PER_LIMB}, + tests::proof_inputs_for_backend, + }; + + use rstest::{fixture, rstest}; + + use super::*; + + use halo2_proofs::dev::{metadata::Constraint, MockProver, VerifyFailure}; + + // Returns the instance columns and the circuit for proof generation. 
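+    // Computed once (`#[once]`) and shared by all tests below, since `Prover::new` and
+    // `Verifier::new` each run full key generation, which is relatively expensive.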
+ #[fixture] + #[once] + fn proof_input() -> (Vec>, AuthDecodeCircuit) { + let p = Prover::new(); + let v = Verifier::new(); + let input = proof_inputs_for_backend(p.clone(), v)[0].clone(); + ( + prepare_instance(input.public(), p.usable_bytes()), + prepare_circuit(input.private(), p.usable_bytes()), + ) + } + + #[fixture] + #[once] + fn k() -> u32 { + crate::backend::halo2::fixtures::k() + } + + // Expects verification to succeed when the correct proof generation inputs are used. + #[rstest] + fn test_ok(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + assert!(prover.verify().is_ok()); + } + + // Expects verification to fail when the plaintext is wrong. + #[rstest] + fn test_bad_plaintext(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + // Flip the lowest bit of the first field element. + let bit = proof_input.1.plaintext[0][3][63]; + let new_bit = F::one() - bit; + proof_input.1.plaintext[0][3][63] = new_bit; + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + assert!(prover.verify().is_err()); + } + + // Expects verification to fail when the plaintext salt is wrong. + #[rstest] + fn test_bad_plaintext_salt(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + proof_input.1.plaintext_salt += F::one(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + assert!(prover.verify().is_err()); + } + + // Expects verification to fail when the encoding sum salt is wrong. + #[rstest] + fn test_bad_encoding_sum_salt(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + proof_input.1.encoding_sum_salt += F::one(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + assert!(prover.verify().is_err()); + } + + // Expects verification to fail when a delta is wrong. + #[rstest] + fn test_bad_delta(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + // Note that corrupting the delta corresponding to a bit with the value 0 will not cause a + // verification failure, since the dot product will not be affected by the corruption. + + // Find the index of the plaintext bit with the value 1 in the low limb of the first field + // element. + let mut index: Option = None; + for (idx, bit) in proof_input.1.plaintext[0][3].iter().enumerate() { + if *bit == F::one() { + index = Some(idx); + break; + } + } + + // Corrupt the corresponding delta on the 4th row in the `index`-th column. + proof_input.0[index.unwrap()][3] += F::one(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + assert!(prover.verify().is_err()); + } + + // Expects verification to fail when the plaintext hash is wrong. + #[rstest] + fn test_bad_plaintext_hash(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + // There are as many instance columns with deltas as there are `BIT_COLUMNS`. + // The value that we need is in the column after the deltas on the first row. 
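+        // (`prepare_instance` lays the columns out as [deltas_0 .. deltas_63, public_inputs],
+        // with the plaintext hash, the encoding sum hash and the zero sum on rows 0, 1 and 2 of
+        // the `public_inputs` column.)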
+ proof_input.0[BITS_PER_LIMB][0] += F::one(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + assert!(prover.verify().is_err()); + } + + // Expects verification to fail when the encoding sum hash is wrong. + #[rstest] + fn test_bad_encoding_sum_hash(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + // There are as many instance columns with deltas as there are `BIT_COLUMNS`. + // The value that we need is in the column after the deltas on the second row. + proof_input.0[BITS_PER_LIMB][1] += F::one(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + assert!(prover.verify().is_err()); + } + + // Expects verification to fail when the zero sum is wrong. + #[rstest] + fn test_bad_zero_sum(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + // There are as many instance columns with deltas as there are `BIT_COLUMNS`. + // The value that we need is in the column after the deltas on the third row. + proof_input.0[BITS_PER_LIMB][2] += F::one(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + assert!(prover.verify().is_err()); + } + + // Expects an unsatisfied constraint in the "binary_check" gate when not all bits of the plaintext + // are binary. + #[rstest] + fn test_binary_check_fail(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + unsafe { + TEST_BINARY_CHECK_FAIL_IS_RUNNING = true; + } + + proof_input.1.plaintext[1][2][34] = F::one() + F::one(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + + // We may need to change gate index here if we modify the circuit. + let expected_failed_constraint: Constraint = ((7, "binary_check").into(), 34, "").into(); + + match &prover.verify().err().unwrap()[0] { + VerifyFailure::ConstraintNotSatisfied { + constraint, + location: _, + cell_values: _, + } => assert!(constraint == &expected_failed_constraint), + _ => panic!("An unexpected constraint was unsatisfied"), + } + } + + // Expects an unsatisfied constraint in the "eight_bits_zero" gate when not all of the 8 MSBs of a + // field element are zeroes. + #[rstest] + fn test_eight_bits_zero_fail(proof_input: &(Vec>, AuthDecodeCircuit), k: &u32) { + let mut proof_input: (Vec>, AuthDecodeCircuit) = proof_input.clone(); + + // Set the MSB to 1. + proof_input.1.plaintext[0][0][0] = F::one(); + + let prover = MockProver::run(*k, &proof_input.1, proof_input.0).unwrap(); + + // We may need to change gate index here if we modify the circuit. 
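+        // (The tuple reads as ((gate index, gate name), constraint index within the gate,
+        // constraint name); constraint index 0 here corresponds to the MSB set above.)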
+ let expected_failed_constraint: Constraint = ((13, "eight_bits_zero").into(), 0, "").into(); + + match &prover.verify().err().unwrap()[0] { + VerifyFailure::ConstraintNotSatisfied { + constraint, + location: _, + cell_values: _, + } => assert!(constraint == &expected_failed_constraint), + _ => panic!("An unexpected constraint was unsatisfied"), + } + } +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/halo2/utils.rs b/crates/components/authdecode/authdecode-core/src/backend/halo2/utils.rs new file mode 100644 index 0000000000..4e8806c0e3 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/halo2/utils.rs @@ -0,0 +1,285 @@ +use cfg_if::cfg_if; +use ff::{Field, FromUniformBytes}; +use halo2_proofs::halo2curves::bn256::Fr as F; +use itybity::{FromBitIterator, ToBits}; + +#[cfg(test)] +use crate::backend::halo2::prover::TEST_BINARY_CHECK_FAIL_IS_RUNNING; + +/// Converts big-endian bytes into a field element by reducing by the modulus. +/// +/// # Arguments +/// +/// * `bytes` - The bytes to be converted. +/// +/// # Panics +/// +/// Panics if the count of bytes is > 64. +pub fn bytes_be_to_f(mut bytes: Vec) -> F { + bytes.reverse(); + let mut wide = [0u8; 64]; + wide[0..bytes.len()].copy_from_slice(&bytes); + F::from_uniform_bytes(&wide) +} + +/// Decomposes a field element into 256 bits in MSB-first bit order. +/// +/// # Arguments +/// +/// * `f` - The field element to decompose. +pub fn f_to_bits(f: &F) -> [bool; 256] { + let mut bytes = f.to_bytes(); + // Reverse to get bytes in big-endian. + bytes.reverse(); + // It is safe to `unwrap` since 32 bytes will always convert to 256 bits. + bytes.to_msb0_vec().try_into().unwrap() +} + +/// Converts a slice of `items` into a matrix in column-major order performing the necessary padding. +/// +/// Each chunk of `chunk_size` items will be padded with the default value on the left in order to +/// bring the size of the chunk to `pad_chunk_to_size`. Then a matrix of `row_count` rows and +/// `column_count` columns will be filled with items in row-major order, filling any empty trailing +/// cells with the default value. Finally, the matrix will be transposed. +/// +/// # Arguments +/// +/// * `items` - The items to be arranged into a matrix. +/// * `chunk_size` - The size of a chunk of items that has to be padded to the `pad_chunk_to_size` size. +/// * `pad_chunk_to_size` - The size to which a chunk of items will be padded. +/// * `row_count` - The amount of rows in the resulting matrix. +/// * `column_count` - The amount of columns in the resulting matrix. +/// +/// # Panics +/// +/// Panics if the matrix cannot be created. +pub fn slice_to_columns( + items: &[V], + chunk_size: usize, + pad_chunk_to_size: usize, + row_count: usize, + column_count: usize, +) -> Vec> +where + V: Default + Clone, +{ + let total = row_count * column_count; + assert!(pad_chunk_to_size >= chunk_size); + + // Left-pad each individual chunk. + let mut items = items + .chunks(chunk_size) + .flat_map(|chunk| { + let mut v = vec![V::default(); pad_chunk_to_size - chunk.len()]; + v.extend(chunk.to_vec()); + v + }) + .collect::>(); + + assert!(items.len() <= total); + + // Fill empty cells of the matrix. + items.extend(vec![V::default(); total - items.len()]); + + // Create a row-major matrix. + let items = items + .chunks(column_count) + .map(|c| c.to_vec()) + .collect::>(); + + debug_assert!(items.len() == row_count); + + // Transpose to column-major. 
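+    // (e.g. the row-major [[1, 2], [3, 4]] becomes [[1, 3], [2, 4]])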
+ transpose_matrix(items) +} + +/// Composes the 64 `bits` of a limb with the given `index` into a field element, left shifting if +/// needed. `bits` are in MSB-first order. The limb with `index` 0 is the highest limb. +/// +/// # Arguments +/// +/// * `bits` - The bits to be composed. +/// * `index` - The index of a limb to be composed. +/// +/// # Panics +/// +/// Panics if limb index > 3 or if any of the `bits` is not a boolean value. +#[allow(clippy::collapsible_else_if)] +pub fn compose_bits(bits: &[F; 64], index: usize) -> F { + assert!(index < 4); + let bits = bits + .iter() + .map(|bit| { + if *bit == F::zero() { + false + } else if *bit == F::one() { + true + } else { + cfg_if! { + if #[cfg(test)] { + if unsafe{TEST_BINARY_CHECK_FAIL_IS_RUNNING} { + // Don't panic, use an arbitrary valid bit value. + true + } else { + // For all other tests, panic as usual. + panic!("field element is not a boolean value"); + } + } + else { + panic!("field element is not a boolean value"); + } + } + } + }) + .collect::>(); + + let two = F::one() + F::one(); + + // Left-shift. + bits_to_f(&bits) * two.pow([((3 - index as u64) * 64).to_le()]) +} + +/// Transposes a matrix. +/// +/// # Panics +/// +/// Panics if `matrix` is not a rectangular matrix. +fn transpose_matrix(matrix: Vec>) -> Vec> +where + V: Clone, +{ + let len = matrix[0].len(); + matrix[1..].iter().for_each(|row| assert!(row.len() == len)); + + (0..len) + .map(|i| { + matrix + .iter() + .map(|inner| inner[i].clone()) + .collect::>() + }) + .collect::>() +} + +/// Converts bits in MSB-first order into BE bytes. The bits will be internally left-padded +/// with zeroes to the nearest multiple of 8. +fn boolvec_to_u8vec(bv: &[bool]) -> Vec { + // Reverse to lsb0 since `itybity` can only pad the rightmost bits. + let mut b = Vec::::from_lsb0_iter(bv.iter().rev().copied()); + // Reverse to get big endian byte order. + b.reverse(); + b +} + +/// Converts bits in MSB-first order into a field element by reducing by the modulus. +/// +/// # Panics +/// +/// Panics if the count of bits is > 512. +fn bits_to_f(bits: &[bool]) -> F { + bytes_be_to_f(boolvec_to_u8vec(bits)) +} + +#[cfg(test)] +mod tests { + use num::BigUint; + + use super::*; + + #[test] + fn test_bytes_be_to_f() { + assert_eq!(bytes_be_to_f(vec![1u8, 2u8]), F::from(258u64)); + } + + #[test] + fn test_f_to_bits() { + let mut bits = vec![false; 246]; + bits.extend([ + // 01 0000 0100 == 260 + false, true, false, false, false, false, false, true, false, false, + ]); + let expected: [bool; 256] = bits.try_into().unwrap(); + assert_eq!(f_to_bits(&F::from(260u64)), expected); + } + + #[test] + fn test_slice_to_columns() { + let slice = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + + // First the matrix will be padded and chunked. 
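+        // (For `expected1` below: chunks of up to 3 items are left-padded with zeroes to size 5,
+        // giving 20 items that fill a 6x4 row-major matrix, with the last 4 cells zero-filled.)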
+ // It will look like this in row-major order: + // 0 0 1 2 + // 3 0 0 4 + // 5 6 0 0 + // 7 8 9 0 + // 0 0 0 10 + // 0 0 0 0 + // Then it will be transposed to column-major order: + let expected1 = vec![ + vec![0, 3, 5, 7, 0, 0], + vec![0, 0, 6, 8, 0, 0], + vec![1, 0, 0, 9, 0, 0], + vec![2, 4, 0, 0, 10, 0], + ]; + let expected2 = vec![ + vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + vec![0, 1, 0, 4, 0, 7, 0, 0, 0, 0], + vec![0, 2, 0, 5, 0, 8, 0, 0, 0, 0], + vec![0, 3, 0, 6, 0, 9, 0, 10, 0, 0], + ]; + assert_eq!(slice_to_columns(&slice, 3, 5, 6, 4), expected1); + assert_eq!(slice_to_columns(&slice, 3, 8, 10, 4), expected2); + } + + #[test] + fn test_compose_bits() { + let two = BigUint::from(2u128); + let mut bits: [F; 64] = (0..64) + .map(|_| F::zero()) + .collect::>() + .try_into() + .unwrap(); + + for (i, expected) in (0..4).zip([1, 3, 7, 15]) { + // On each iteration, set one more LSB. + bits[63 - i] = F::one(); + assert_eq!( + compose_bits(&bits, 3 - i), + bytes_be_to_f( + (BigUint::from(expected as u32) * two.pow(64 * i as u32)).to_bytes_be() + ) + ); + } + } + + #[test] + fn test_transpose_matrix() { + let matrix = vec![ + vec![1, 2, 3], + vec![4, 5, 6], + vec![7, 8, 9], + vec![10, 11, 12], + ]; + + let expected = vec![vec![1, 4, 7, 10], vec![2, 5, 8, 11], vec![3, 6, 9, 12]]; + assert_eq!(transpose_matrix(matrix), expected); + } + + #[test] + fn test_boolvec_to_u8vec() { + let bits = [true, false]; + assert_eq!(boolvec_to_u8vec(&bits), [2]); + + let bits = [true, false, false, false, false, false, false, true, true]; + assert_eq!(boolvec_to_u8vec(&bits), [1, 3]); + } + + #[test] + fn test_bits_to_f() { + // 01 0000 0011 == 259 + let bits = [ + false, true, false, false, false, false, false, false, true, true, + ]; + assert_eq!(bits_to_f(&bits), F::from(259u64)); + } +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/halo2/verifier.rs b/crates/components/authdecode/authdecode-core/src/backend/halo2/verifier.rs new file mode 100644 index 0000000000..89d5573372 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/halo2/verifier.rs @@ -0,0 +1,134 @@ +use crate::{ + backend::{ + halo2::{circuit::USABLE_BYTES, Bn256F, CHUNK_SIZE, PARAMS}, + traits::VerifierBackend as Backend, + }, + verifier::VerifierError, + Proof, PublicInput, +}; + +use ff::{FromUniformBytes, WithSmallOrderMulGroup}; +use halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr as F, G1Affine}, + plonk::{verify_proof as verify_plonk_proof, VerifyingKey}, + poly::{ + commitment::{CommitmentScheme, Verifier as CommitmentVerifier}, + kzg::{ + commitment::KZGCommitmentScheme, multiopen::VerifierGWC, strategy::AccumulatorStrategy, + }, + VerificationStrategy, + }, + transcript::{Blake2bRead, Challenge255, EncodedChallenge, TranscriptReadBuffer}, +}; + +#[cfg(feature = "tracing")] +use tracing::{debug, debug_span, instrument, Instrument}; + +use super::{onetimesetup::verification_key, prepare_instance}; + +/// The Verifier of the authdecode circuit. +pub struct Verifier { + /// The verification key. + verification_key: VerifyingKey, +} + +impl Default for Verifier { + fn default() -> Self { + Self::new() + } +} + +impl Verifier { + /// Generates a verification key and creates a new verifier. + // + // To prevent the latency caused by the generation of a verification key, consider caching + // the verification key and use `new_with_key` instead. + pub fn new() -> Self { + Self { + verification_key: verification_key(), + } + } + + /// Creates a new verifier with the provided key. 
+ pub fn new_with_key(verification_key: VerifyingKey) -> Self { + Self { verification_key } + } + + /// How many least significant bytes of a field element are used to pack the plaintext into. + fn usable_bytes(&self) -> usize { + USABLE_BYTES + } +} + +impl Backend for Verifier { + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] + fn verify( + &self, + inputs: Vec>, + proofs: Vec, + ) -> Result<(), VerifierError> { + // XXX: using the default strategy of "one proof proves one chunk of plaintext". + if inputs.len() != proofs.len() { + return Err(VerifierError::WrongProofCount(inputs.len(), proofs.len())); + } + + for (input, proof) in inputs.into_iter().zip(proofs) { + let instance_columns = prepare_instance(&input, self.usable_bytes()); + + verify_proof::< + KZGCommitmentScheme, + VerifierGWC<'_, Bn256>, + _, + Blake2bRead<_, _, Challenge255<_>>, + AccumulatorStrategy<_>, + >( + &PARAMS, + &self.verification_key, + &proof.0, + &[&instance_columns + .iter() + .map(|col| col.as_slice()) + .collect::>()], + )?; + } + + Ok(()) + } + + fn chunk_size(&self) -> usize { + CHUNK_SIZE + } +} + +#[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] +fn verify_proof< + 'a, + 'params, + Scheme: CommitmentScheme, + V: CommitmentVerifier<'params, Scheme>, + E: EncodedChallenge, + T: TranscriptReadBuffer<&'a [u8], Scheme::Curve, E>, + Strategy: VerificationStrategy<'params, Scheme, V, Output = Strategy>, +>( + params_verifier: &'params Scheme::ParamsVerifier, + vk: &VerifyingKey, + proof: &'a [u8], + instances: &[&[&[F]]], +) -> Result<(), VerifierError> +where + Scheme::Scalar: Ord + WithSmallOrderMulGroup<3> + FromUniformBytes<64>, +{ + let mut transcript = T::init(proof); + + let strategy = Strategy::new(params_verifier); + let strategy = verify_plonk_proof(params_verifier, vk, strategy, instances, &mut transcript) + .map_err(|e| VerifierError::VerificationFailed(e.to_string()))?; + + if !strategy.finalize() { + return Err(VerifierError::VerificationFailed( + "VerificationStrategy::finalize() returned false".to_string(), + )); + } + + Ok(()) +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/mock/circuit.rs b/crates/components/authdecode/authdecode-core/src/backend/mock/circuit.rs new file mode 100644 index 0000000000..54ea64f3f2 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/mock/circuit.rs @@ -0,0 +1,61 @@ +use crate::backend::{mock::prover::hash, traits::Field}; +use itybity::ToBits; + +use super::MockField; + +/// Checks in the clear that the given inputs satisfy all constraints of the AuthDecode circuit. +pub fn is_circuit_satisfied( + plaintext_hash: MockField, + encoding_sum_hash: MockField, + zero_sum: MockField, + deltas: Vec, + mut plaintext: Vec, + plaintext_salt: MockField, + encoding_sum_salt: MockField, +) -> bool { + assert!(plaintext.len() * 8 == deltas.len()); + // Compute dot product of plaintext and deltas. + let dot_product = plaintext.to_msb0_vec().into_iter().zip(deltas).fold( + MockField::zero(), + |acc, (bit, delta)| { + let product = if bit { delta } else { MockField::zero() }; + acc + product + }, + ); + + // Compute encoding sum, add salt, hash it and compare to the expected hash. + let encoding_sum = zero_sum + dot_product; + let mut enc_sum = encoding_sum.to_bytes_be(); + + // Convert salt into bytes padding the most significant bytes if needed. 
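+    // (`to_bytes_be` on the BigInt-backed `MockField` yields a minimal-length encoding, so a
+    // numerically small salt produces fewer than 16 bytes and must be left-padded.)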
+ let salt_bytes = encoding_sum_salt.to_bytes_be(); + let mut salt = [0u8; 16]; + salt[16 - salt_bytes.len()..].copy_from_slice(&salt_bytes); + enc_sum.extend(salt); + + let hash_bytes = hash(&enc_sum); + + let digest = MockField::from_bytes_be(hash_bytes.to_vec()); + + if digest != encoding_sum_hash { + return false; + } + + // Convert salt into bytes padding the most significant bytes if needed. + let salt_bytes = plaintext_salt.to_bytes_be(); + let mut salt = [0u8; 16]; + salt[16 - salt_bytes.len()..].copy_from_slice(&salt_bytes); + + // Add salt to plaintext, hash it and compare to the expected hash. + plaintext.extend(salt); + + let hash_bytes = hash(&plaintext); + + let digest = MockField::from_bytes_be(hash_bytes.to_vec()); + + if digest != plaintext_hash { + return false; + } + + true +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/mock/mod.rs b/crates/components/authdecode/authdecode-core/src/backend/mock/mod.rs new file mode 100644 index 0000000000..ff7c02b772 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/mock/mod.rs @@ -0,0 +1,122 @@ +use crate::backend::traits::Field; +use bincode; +use num::{bigint::Sign, BigInt}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::ops::{Add, Sub}; + +pub mod circuit; +pub mod prover; +pub mod verifier; + +pub use prover::MockProverBackend; +pub use verifier::MockVerifierBackend; + +/// Chunk size in bytes. +pub(crate) const CHUNK_SIZE: usize = 300; + +/// A mock field element. +#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] +pub struct MockField { + #[serde( + serialize_with = "bigint_serialize", + deserialize_with = "bigint_deserialize" + )] + inner: BigInt, +} + +impl Add for MockField { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + inner: self.inner + rhs.inner, + } + } +} + +impl Sub for MockField { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self { + inner: self.inner - rhs.inner, + } + } +} + +impl Field for MockField { + fn from_bytes_be(bytes: Vec) -> Self { + Self { + inner: BigInt::from_bytes_be(Sign::Plus, &bytes), + } + } + + fn to_bytes_be(self) -> Vec { + let (_, bytes) = self.inner.to_bytes_be(); + bytes + } + + fn zero() -> Self { + Self { + inner: BigInt::from(0u8), + } + } +} + +/// A mock proof. +/// +/// Normally, the prover proves in zk the knowledge of private inputs which satisfy the circuit's +/// constraints. Here the private inputs are simply revealed without zk. +#[derive(Serialize, Deserialize)] +pub struct MockProof { + plaintext: Vec, + plaintext_salt: MockField, + encoding_sum_salt: MockField, +} + +impl MockProof { + /// Creates a new mock proof. + pub fn new( + plaintext: Vec, + plaintext_salt: MockField, + encoding_sum_salt: MockField, + ) -> Self { + Self { + plaintext, + plaintext_salt, + encoding_sum_salt, + } + } + + /// Serializes `self` into bytes. 
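+ ///
+ /// A round-trip sketch (hypothetical values):
+ ///
+ /// ```ignore
+ /// let proof = MockProof::new(vec![1, 2, 3], MockField::zero(), MockField::zero());
+ /// let restored = MockProof::from_bytes(proof.to_bytes());
+ /// ```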
+ pub fn to_bytes(&self) -> Vec { + bincode::serialize(self).unwrap() + } + + /// Deserializes `self` from bytes + pub fn from_bytes(bytes: Vec) -> Self { + bincode::deserialize(&bytes).unwrap() + } +} + +fn bigint_serialize(bigint: &BigInt, serializer: S) -> Result +where + S: Serializer, +{ + let (sign, bytes) = bigint.to_bytes_be(); + assert!(sign == Sign::Plus); + serializer.serialize_bytes(&bytes) +} + +fn bigint_deserialize<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let bytes: Vec = Vec::deserialize(deserializer)?; + Ok(BigInt::from_bytes_be(Sign::Plus, &bytes)) +} + +/// Returns a pair of mock backends. +pub fn backend_pair() -> (MockProverBackend, MockVerifierBackend) { + (MockProverBackend::new(), MockVerifierBackend::new()) +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/mock/prover.rs b/crates/components/authdecode/authdecode-core/src/backend/mock/prover.rs new file mode 100644 index 0000000000..4dba5d4caa --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/mock/prover.rs @@ -0,0 +1,91 @@ +use crate::{ + backend::{ + mock::{MockField, MockProof, CHUNK_SIZE}, + traits::{Field, ProverBackend}, + }, + prover::{ProverError, ProverInput}, + Proof, +}; + +use rand::{thread_rng, Rng}; + +#[cfg(any(test, feature = "fixtures"))] +use std::any::Any; + +/// A mock prover backend. +#[derive(Default)] +pub struct MockProverBackend {} + +impl MockProverBackend { + pub fn new() -> Self { + Self {} + } +} + +impl ProverBackend for MockProverBackend { + fn commit_plaintext(&self, mut plaintext: Vec) -> (MockField, MockField) { + assert!(plaintext.len() <= self.chunk_size()); + + // Add random salt to plaintext and hash it. + let salt: [u8; 16] = thread_rng().gen(); + plaintext.extend(salt); + + let hash_bytes = &hash(&plaintext); + + ( + MockField::from_bytes_be(hash_bytes.to_vec()), + MockField::from_bytes_be(salt.to_vec()), + ) + } + + fn commit_plaintext_with_salt(&self, _plaintext: Vec, _salt: MockField) -> MockField { + unimplemented!() + } + + fn commit_encoding_sum(&self, encoding_sum: MockField) -> (MockField, MockField) { + // Add random salt to encoding_sum and hash it. + let salt: [u8; 16] = thread_rng().gen(); + + let mut enc_sum = encoding_sum.to_bytes_be(); + enc_sum.extend(salt); + + let hash_bytes = hash(&enc_sum); + + ( + MockField::from_bytes_be(hash_bytes.to_vec()), + MockField::from_bytes_be(salt.to_vec()), + ) + } + + fn chunk_size(&self) -> usize { + CHUNK_SIZE + } + + fn prove(&self, input: Vec>) -> Result, ProverError> { + // Use the default strategy of one proof for one chunk. 
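+ // Consequently, `input.len()` chunks produce exactly `input.len()` proofs,
+ // matching the count check performed by the verifier.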
+ Ok(input + .into_iter() + .map(|input| { + Proof::new( + &MockProof::new( + input.private().plaintext().clone(), + input.private().plaintext_salt().clone(), + input.private().encoding_sum_salt().clone(), + ) + .to_bytes(), + ) + }) + .collect::>()) + } + + #[cfg(any(test, feature = "fixtures"))] + fn as_any(&self) -> &dyn Any { + self + } +} + +pub fn hash(bytes: &[u8]) -> [u8; 32] { + let mut hasher = blake3::Hasher::new(); + hasher.update(bytes); + hasher.finalize().into() +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/mock/verifier.rs b/crates/components/authdecode/authdecode-core/src/backend/mock/verifier.rs new file mode 100644 index 0000000000..3633edc481 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/mock/verifier.rs @@ -0,0 +1,54 @@ +use crate::{ + backend::{ + mock::{circuit::is_circuit_satisfied, MockField, MockProof, CHUNK_SIZE}, + traits::VerifierBackend as Backend, + }, + verifier::VerifierError, + Proof, PublicInput, +}; + +/// A mock verifier backend. +#[derive(Default)] +pub struct MockVerifierBackend {} + +impl MockVerifierBackend { + pub fn new() -> Self { + Self {} + } +} + +impl Backend for MockVerifierBackend { + fn verify( + &self, + inputs: Vec>, + proofs: Vec, + ) -> Result<(), VerifierError> { + // Using the default strategy of one proof for one chunk. + if inputs.len() != proofs.len() { + return Err(VerifierError::WrongProofCount(inputs.len(), proofs.len())); + } + + for (proof, input) in proofs.iter().zip(inputs) { + let proof = MockProof::from_bytes(proof.0.to_vec()); + if !is_circuit_satisfied( + input.plaintext_hash, + input.encoding_sum_hash, + input.zero_sum, + input.deltas, + proof.plaintext, + proof.plaintext_salt, + proof.encoding_sum_salt, + ) { + return Err(VerifierError::VerificationFailed( + "Mock circuit was not satisfied".to_string(), + )); + }; + } + + Ok(()) + } + + fn chunk_size(&self) -> usize { + CHUNK_SIZE + } +} diff --git a/crates/components/authdecode/authdecode-core/src/backend/mod.rs b/crates/components/authdecode/authdecode-core/src/backend/mod.rs new file mode 100644 index 0000000000..42483ea021 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/mod.rs @@ -0,0 +1,5 @@ +pub mod halo2; +pub mod traits; + +#[cfg(any(test, feature = "mock"))] +pub mod mock; diff --git a/crates/components/authdecode/authdecode-core/src/backend/traits.rs b/crates/components/authdecode/authdecode-core/src/backend/traits.rs new file mode 100644 index 0000000000..0cd9a2efa5 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/backend/traits.rs @@ -0,0 +1,98 @@ +//! Traits for the prover backend and the verifier backend. + +use crate::{ + prover::{ProverError, ProverInput}, + verifier::VerifierError, + Proof, PublicInput, +}; + +#[cfg(any(test, feature = "fixtures"))] +use std::any::Any; + +/// A trait for zk proof generation backend. +pub trait ProverBackend +where + F: Field, +{ + /// Creates a commitment to the plaintext, padding the plaintext if necessary. + /// + /// Returns the commitment and the salt used to create the commitment. + /// + /// # Panics + /// + /// Panics if the length of the plaintext exceeds the allowed maximum. + /// + /// # Arguments + /// + /// * `plaintext` - The plaintext to commit to. + fn commit_plaintext(&self, plaintext: Vec) -> (F, F); + + /// Creates a commitment to the plaintext with the provided salt, padding the plaintext if + /// necessary. + /// + /// Returns the commitment. 
+ /// + /// # Panics + /// + /// Panics if the length of the plaintext exceeds the allowed maximum. + /// + /// # Arguments + /// + /// * `plaintext` - The plaintext to commit to. + /// * `salt` - The salt of the commitment. + fn commit_plaintext_with_salt(&self, plaintext: Vec, salt: F) -> F; + + /// Creates a commitment to the encoding sum. + /// + /// Returns the commitment and the salt used to create the commitment. + /// + /// # Arguments + /// + /// * `encoding_sum` - The sum of the encodings to commit to. + fn commit_encoding_sum(&self, encoding_sum: F) -> (F, F); + + /// Given the `inputs` to the AuthDecode circuit, generates and returns `Proof`(s). + /// + /// # Arguments + /// + /// * `inputs` - A collection of circuit inputs. Each input proves a single chunk + /// of plaintext. + fn prove(&self, inputs: Vec>) -> Result, ProverError>; + + /// The bytesize of a single chunk of plaintext. Does not include the salt. + fn chunk_size(&self) -> usize; + + // Testing only. Used to downcast to a concrete type. + #[cfg(any(test, feature = "fixtures"))] + fn as_any(&self) -> &dyn Any; +} + +/// A trait for zk proof verification backend. +pub trait VerifierBackend: Send +where + F: Field, +{ + /// Verifies multiple inputs against multiple proofs. + /// + /// The backend internally determines which inputs correspond to which proofs. + fn verify(&self, inputs: Vec>, proofs: Vec) -> Result<(), VerifierError>; + + /// The bytesize of a single chunk of plaintext. Does not include the salt. + fn chunk_size(&self) -> usize; +} + +/// Methods for working with a field element. +pub trait Field { + /// Creates a new field element from bytes in big-endian byte order. + fn from_bytes_be(bytes: Vec) -> Self + where + Self: Sized; + + /// Returns the field element as bytes in big-endian byte order. + fn to_bytes_be(self) -> Vec; + + /// Returns zero, the additive identity. + fn zero() -> Self + where + Self: Sized; +} diff --git a/crates/components/authdecode/authdecode-core/src/encodings/active.rs b/crates/components/authdecode/authdecode-core/src/encodings/active.rs new file mode 100644 index 0000000000..0a5aacd2af --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/encodings/active.rs @@ -0,0 +1,167 @@ +use crate::{backend::traits::Field, encodings::Encoding, id::IdCollection}; + +use getset::Getters; +use itybity::FromBitIterator; + +/// A non-empty collection of active bit encodings with the associated plaintext value. +#[derive(Clone, PartialEq, Debug, Getters, Default)] +pub struct ActiveEncodings { + /// The encoding for each bit of the plaintext in MSB0 bit order. + #[getset(get = "pub")] + encodings: Vec, + /// A collection of ids of each bit of the encoded plaintext. + /// + /// This type will not enforce that when there are duplicate ids in the collection, the values of + /// the corresponding encodings must match. + #[getset(get = "pub")] + ids: I, +} + +impl ActiveEncodings +where + I: IdCollection, +{ + /// Creates a new collection of active encodings. + /// + /// # Arguments + /// + /// * `encodings` - The active encodings. + /// * `ids` - The collection of ids of each bit of the encoded plaintext. + /// + /// # Panics + /// + /// Panics if either `encodings` or `ids` is empty. + pub fn new(encodings: Vec, ids: I) -> Self { + assert!(!encodings.is_empty() && !ids.is_empty()); + + Self { encodings, ids } + } + + /// Creates a new collection from an iterator. + /// + /// # Arguments + /// + /// * `iter` - The iterator from which to create the collection. 
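+ ///
+ /// A minimal sketch (`front_chunk` and `back_chunk` are hypothetical values
+ /// produced by `into_chunks` below):
+ ///
+ /// ```ignore
+ /// let whole = ActiveEncodings::new_from_iter([front_chunk, back_chunk]);
+ /// ```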
+ pub fn new_from_iter>(iter: It) -> Self { + let (encodings, ids): (Vec<_>, Vec<_>) = + iter.into_iter().map(|e| (e.encodings, e.ids)).unzip(); + + Self { + encodings: encodings.into_iter().flatten().collect(), + ids: I::new_from_iter(ids), + } + } + + /// Convert `self` into an iterator over chunks of the collection. If `chunk_size` does not divide + /// the length of the collection, then the last chunk will not have length `chunk_size`. + /// + /// # Arguments + /// + /// * `chunk_size` - The size of a chunk. + pub fn into_chunks(self, chunk_size: usize) -> ActiveEncodingsChunks { + ActiveEncodingsChunks { + chunk_size, + encodings: self.encodings.into_iter(), + ids: self.ids, + } + } + + #[allow(clippy::len_without_is_empty)] + /// Returns the number of active encodings. + pub fn len(&self) -> usize { + self.encodings.len() + } + + /// Returns the plaintext encoded by this collection. + pub fn plaintext(&self) -> Vec { + Vec::::from_msb0_iter(self.encodings.iter().map(|enc| *enc.bit())) + } +} + +impl ActiveEncodings +where + I: IdCollection, +{ + /// Computes the arithmetic sum of the encodings. + pub fn compute_sum(&self) -> F + where + F: Field + std::ops::Add, + { + self.encodings.iter().fold(F::zero(), |acc, x| -> F { + acc + F::from_bytes_be(x.value().to_vec()) + }) + } +} + +pub struct ActiveEncodingsChunks { + chunk_size: usize, + encodings: as IntoIterator>::IntoIter, + ids: I, +} + +impl Iterator for ActiveEncodingsChunks +where + I: IdCollection, +{ + type Item = ActiveEncodings; + + fn next(&mut self) -> Option { + if self.encodings.len() == 0 { + None + } else { + Some(ActiveEncodings { + encodings: self + .encodings + .by_ref() + .take(self.chunk_size) + .collect::>(), + ids: self.ids.drain_front(self.chunk_size), + }) + } + } +} + +#[cfg(test)] +mod tests { + + use rand::SeedableRng; + use rand_chacha::ChaCha12Rng; + + use crate::{ + encodings::Encoding, + mock::{Direction, MockBitIds}, + }; + + use super::*; + + // Tests that chunking of active encodings works correctly. + #[allow(clippy::single_range_in_vec_init)] + #[test] + fn test_active_encodings_chunks() { + const BYTE_COUNT: usize = 22; + const CHUNK_BYTESIZE: usize = 14; + + let mut rng = ChaCha12Rng::from_seed([0; 32]); + let all_encodings = [Encoding::random(&mut rng); BYTE_COUNT * 8].to_vec(); + + let ids = MockBitIds::new(Direction::Sent, &[0..BYTE_COUNT]); + let active = ActiveEncodings::new(all_encodings.clone(), ids); + + let mut chunk_iter = active.into_chunks(CHUNK_BYTESIZE * 8); + + // The first chunk will contain encodings for `CHUNK_BYTESIZE` bytes. + let expected_chunk1_encodings = ActiveEncodings::new( + all_encodings[0..CHUNK_BYTESIZE * 8].to_vec(), + MockBitIds::new(Direction::Sent, &[0..CHUNK_BYTESIZE]), + ); + + // The second chunk will contain encodings for `BYTE_COUNT - CHUNK_BYTESIZE` bytes. 
+ let expected_chunk2_encodings = ActiveEncodings::new( + all_encodings[CHUNK_BYTESIZE * 8..BYTE_COUNT * 8].to_vec(), + MockBitIds::new(Direction::Sent, &[CHUNK_BYTESIZE..BYTE_COUNT]), + ); + + assert_eq!(chunk_iter.next().unwrap(), expected_chunk1_encodings); + assert_eq!(chunk_iter.next().unwrap(), expected_chunk2_encodings); + } +} diff --git a/crates/components/authdecode/authdecode-core/src/encodings/encoding.rs b/crates/components/authdecode/authdecode-core/src/encodings/encoding.rs new file mode 100644 index 0000000000..37119f2118 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/encodings/encoding.rs @@ -0,0 +1,38 @@ +use crate::SSP; + +use getset::Getters; + +#[cfg(test)] +use rand::Rng; +#[cfg(test)] +use rand_core::CryptoRng; + +/// An encoding of either the 0 or the 1 value of a bit. +#[derive(Clone, PartialEq, Debug, Default, Copy, Getters)] +pub struct Encoding { + /// The value of the encoding represented as big-endian bytes. + #[getset(get = "pub")] + value: [u8; SSP / 8], + /// The value of the bit that the encoding encodes. + #[getset(get = "pub")] + bit: bool, +} + +impl Encoding { + /// Creates a new instance. + pub fn new(value: [u8; SSP / 8], bit: bool) -> Self { + Self { value, bit } + } + + #[cfg(test)] + /// Returns a random encoding using the provided RNG. + pub fn random(rng: &mut R) -> Self { + Self::new(rng.gen(), rng.gen()) + } + + #[cfg(test)] + /// Sets the value of the bit that the encoding encodes. + pub fn set_bit(&mut self, bit: bool) { + self.bit = bit; + } +} diff --git a/crates/components/authdecode/authdecode-core/src/encodings/full.rs b/crates/components/authdecode/authdecode-core/src/encodings/full.rs new file mode 100644 index 0000000000..7b64524329 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/encodings/full.rs @@ -0,0 +1,191 @@ +use crate::{backend::traits::Field, encodings::Encoding, id::IdCollection}; + +use getset::Getters; + +/// A non-empty collection of full encodings. Each item in the collection is the encodings of the 0 +/// and 1 values of a bit. +#[derive(Clone, PartialEq, Default, Debug, Getters)] +pub struct FullEncodings { + /// Full encodings for each bit. + encodings: Vec<[Encoding; 2]>, + /// The id of each bit encoded by the encodings of this collection. + /// + /// This type will not enforce that when there are duplicate ids in the collection, the values of + /// the corresponding encodings must match. + #[getset(get = "pub")] + ids: I, +} + +impl FullEncodings +where + I: IdCollection, +{ + /// Creates a new collection of full encodings. + /// + /// # Arguments + /// + /// * `encodings` - The pairs of encodings. + /// * `ids` - The collection of ids of each bit of the encoded plaintext. + /// + /// # Panics + /// + /// Panics if either `encodings` or `ids` is empty. + pub fn new(encodings: Vec<[Encoding; 2]>, ids: I) -> Self { + assert!(!encodings.is_empty() && !ids.is_empty()); + + for pair in encodings.clone() { + assert!(!pair[0].bit() && *pair[1].bit()); + } + + Self { encodings, ids } + } + + /// Returns the number of full encodings. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.encodings.len() + } + + /// Convert `self` into an iterator over chunks of the collection. If `chunk_size` does not divide + /// the length of the collection, then the last chunk will not have length `chunk_size`. + /// + /// # Arguments + /// + /// * `chunk_size` - The size of a chunk. 
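+ ///
+ /// For instance, 22 * 8 encodings chunked with `chunk_size = 14 * 8` yield
+ /// two chunks of 14 * 8 and 8 * 8 encodings, as exercised in the test below.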
+ pub fn into_chunks(self, chunk_size: usize) -> FullEncodingsChunks { + FullEncodingsChunks { + chunk_size, + encodings: self.encodings.into_iter(), + ids: self.ids, + } + } + + /// Drains `count` encodings from the front of the collection. + /// + /// # Arguments + /// + /// * `count` - The amount of encodings to drain. + /// + /// # Panics + /// + /// Panics if the collection contains less than `count` encodings. + pub fn drain_front(&mut self, count: usize) -> Self { + let drained = self.encodings.drain(0..count).collect::>(); + assert!(drained.len() == count); + + Self { + encodings: drained, + ids: self.ids.drain_front(count), + } + } + + /// Computes the arithmetic sum of the encodings of the bit value 0. + pub fn compute_zero_sum(&self) -> F + where + F: Field + std::ops::Add, + { + self.encodings.iter().fold(F::zero(), |acc, x| { + acc + F::from_bytes_be(x[0].value().to_vec()) + }) + } + + /// Computes the arithmetic difference between the encoding of the bit value 1 and the encoding + /// of the bit value 0 for each pair in the collection. + pub fn compute_deltas(&self) -> Vec + where + F: Field + std::ops::Sub, + { + self.encodings + .iter() + .map(|pair| { + let a = F::from_bytes_be(pair[1].value().to_vec()); + let b = F::from_bytes_be(pair[0].value().to_vec()); + a - b + }) + .collect() + } + + #[cfg(any(test, feature = "mock"))] + /// Returns full encodings for each bit. + pub fn encodings(&self) -> &[[Encoding; 2]] { + &self.encodings + } +} + +pub struct FullEncodingsChunks { + chunk_size: usize, + encodings: as IntoIterator>::IntoIter, + ids: I, +} + +impl Iterator for FullEncodingsChunks +where + I: IdCollection, +{ + type Item = FullEncodings; + + fn next(&mut self) -> Option { + if self.encodings.len() == 0 { + None + } else { + Some(FullEncodings { + encodings: self + .encodings + .by_ref() + .take(self.chunk_size) + .collect::>(), + ids: self.ids.drain_front(self.chunk_size), + }) + } + } +} + +#[cfg(test)] +mod tests { + use rand::SeedableRng; + use rand_chacha::ChaCha12Rng; + + use crate::{ + encodings::{Encoding, FullEncodings}, + mock::{Direction, MockBitIds}, + }; + + // Tests that chunking of full encodings works correctly. + #[allow(clippy::single_range_in_vec_init)] + #[test] + fn test_full_encodings_chunks() { + const BYTE_COUNT: usize = 22; + const CHUNK_BYTESIZE: usize = 14; + + let mut rng = ChaCha12Rng::from_seed([0; 32]); + let all_encodings = (0..BYTE_COUNT * 8) + .map(|_| { + let mut pair = [Encoding::random(&mut rng); 2]; + // Set the correct bit values. + pair[0].set_bit(false); + pair[1].set_bit(true); + pair + }) + .collect::>(); + + let ids = MockBitIds::new(Direction::Sent, &[0..BYTE_COUNT]); + let full = FullEncodings::new(all_encodings.clone(), ids); + + let mut chunk_iter = full.into_chunks(CHUNK_BYTESIZE * 8); + + // The first chunk will contain encodings for `CHUNK_BYTESIZE` bytes. + let expected_chunk1_encodings = FullEncodings::new( + all_encodings[0..CHUNK_BYTESIZE * 8].to_vec(), + MockBitIds::new(Direction::Sent, &[0..CHUNK_BYTESIZE]), + ); + + // The second chunk will contain encodings for `BYTE_COUNT - CHUNK_BYTESIZE` bytes. 
+ let expected_chunk2_encodings = FullEncodings::new( + all_encodings[CHUNK_BYTESIZE * 8..BYTE_COUNT * 8].to_vec(), + MockBitIds::new(Direction::Sent, &[CHUNK_BYTESIZE..BYTE_COUNT]), + ); + + assert_eq!(chunk_iter.next().unwrap(), expected_chunk1_encodings); + assert_eq!(chunk_iter.next().unwrap(), expected_chunk2_encodings); + } +} diff --git a/crates/components/authdecode/authdecode-core/src/encodings/mod.rs b/crates/components/authdecode/authdecode-core/src/encodings/mod.rs new file mode 100644 index 0000000000..ad44de0999 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/encodings/mod.rs @@ -0,0 +1,24 @@ +pub(crate) mod active; +mod encoding; +mod full; + +pub use active::ActiveEncodings; +pub use encoding::Encoding; +pub use full::FullEncodings; + +use crate::id::IdCollection; + +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub enum EncodingProviderError { + #[error("Unable to provide an encoding with the given id {0}")] + EncodingWithIdNotAvailable(usize), +} + +/// A provider of full encodings of bits identified by their id. +pub trait EncodingProvider +where + I: IdCollection, +{ + /// Returns full encodings for the given bit ids. + fn get_by_ids(&self, ids: &I) -> Result, EncodingProviderError>; +} diff --git a/crates/components/authdecode/authdecode-core/src/fixtures.rs b/crates/components/authdecode/authdecode-core/src/fixtures.rs new file mode 100644 index 0000000000..7a9ded0d40 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/fixtures.rs @@ -0,0 +1,85 @@ +use crate::{ + encodings::{Encoding, FullEncodings}, + mock::{Direction, MockBitIds, MockEncodingProvider}, + prover::CommitmentData, + SSP, +}; +use itybity::ToBits; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha12Rng; + +// The size of plaintext in bytes; +#[allow(dead_code)] +const PLAINTEXT_SIZE: usize = 1000; + +pub fn commitment_data() -> Vec> { + let mut rng = ChaCha12Rng::from_seed([0; 32]); + + // Generate random plaintext. + let plaintext: Vec = core::iter::repeat_with(|| rng.gen::()) + .take(PLAINTEXT_SIZE) + .collect(); + + // Generate the Verifier's full encodings for each bit of the plaintext. + let full_encodings = full_encodings(PLAINTEXT_SIZE * 8); + + // Prover's active encodings are based on their choice bits. + let active_encodings = choose(&full_encodings, &plaintext.to_msb0_vec()); + + // Prover creates two commitments: to the front and to the tail portions of the plaintext. + // Some middle bits of the plaintext will not be committed to. 
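+ // Concretely, with PLAINTEXT_SIZE = 1000 the committed byte ranges below
+ // are 0..490 and 500..1000, leaving bytes 490..500 uncommitted.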
+ let range1 = 0..PLAINTEXT_SIZE / 2 - 10; + let range2 = PLAINTEXT_SIZE / 2..PLAINTEXT_SIZE; + let bitrange1 = range1.start * 8..range1.end * 8; + let bitrange2 = range2.start * 8..range2.end * 8; + + let bit_ids1 = MockBitIds::new(Direction::Sent, &[range1.clone()]); + let bit_ids2 = MockBitIds::new(Direction::Sent, &[range2.clone()]); + + let commitment1 = CommitmentData::new( + &plaintext[range1.clone()], + &active_encodings[bitrange1], + bit_ids1, + ); + let commitment2 = CommitmentData::new( + &plaintext[range2.clone()], + &active_encodings[bitrange2], + bit_ids2, + ); + + vec![commitment1, commitment2] +} + +pub fn encoding_provider() -> MockEncodingProvider { + #[allow(clippy::single_range_in_vec_init)] + let bit_ids = MockBitIds::new(Direction::Sent, &[0..PLAINTEXT_SIZE]); + + let full_encodings = full_encodings(PLAINTEXT_SIZE * 8) + .iter() + .map(|e| [Encoding::new(e[0], false), Encoding::new(e[1], true)]) + .collect::>(); + + MockEncodingProvider::new(FullEncodings::new(full_encodings, bit_ids)) +} + +/// Returns random full encodings for `len` bits. +fn full_encodings(len: usize) -> Vec<[[u8; 5]; 2]> { + let mut rng = ChaCha12Rng::from_seed([1; 32]); + + // Generate Verifier's full encodings for each bit of the plaintext. + let mut full_encodings = vec![[[0u8; SSP / 8]; 2]; len]; + for elem in full_encodings.iter_mut() { + *elem = rng.gen(); + } + full_encodings +} + +/// Unzips a slice of pairs, returning items corresponding to choice. +pub fn choose(items: &[[T; 2]], choice: &[bool]) -> Vec { + assert!(items.len() == choice.len(), "arrays are different length"); + items + .iter() + .zip(choice) + .map(|(items, choice)| items[*choice as usize].clone()) + .collect() +} diff --git a/crates/components/authdecode/authdecode-core/src/id.rs b/crates/components/authdecode/authdecode-core/src/id.rs new file mode 100644 index 0000000000..23e067b5f6 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/id.rs @@ -0,0 +1,45 @@ +/// A unique identifier. +#[derive(Default, Clone, PartialEq, Eq, Hash, Debug)] +pub struct Id(pub u64); + +/// A trait for working with a collection of ids. +/// +/// It is permissible for the collection to contain duplicate ids. +#[allow(clippy::len_without_is_empty)] +pub trait IdCollection: PartialEq + Default + Clone { + /// Drains and returns `count` ids from the front of the collection, modifying the collection. + /// If the length of the collection is less than `count`, drains the entire collection. + /// + /// # Panics + /// + /// Panics if the `count` is invalid. + /// + /// # Arguments + /// + /// * `count` - The amount of ids to drain. + fn drain_front(&mut self, count: usize) -> Self; + + /// Returns the id of an elements at the given `index` in the collection. + /// + /// # Arguments + /// + /// * `index` - The index of an id. + /// + /// # Panics + /// + /// Panics if there is no id with the given index in the collection. + fn id(&self, index: usize) -> Id; + + /// Returns the amount of ids in the collection. + fn len(&self) -> usize; + + /// Whether the collection is empty. + fn is_empty(&self) -> bool; + + /// Constructs a collection from an iterator over collections. + /// + /// # Panics + /// + /// Panics if a collection cannot be constructed. 
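+ /// (The mock implementation, for instance, panics if the collections in
+ /// `iter` do not all share the same direction.)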
+ fn new_from_iter>(iter: I) -> Self; +} diff --git a/crates/components/authdecode/authdecode-core/src/lib.rs b/crates/components/authdecode/authdecode-core/src/lib.rs new file mode 100644 index 0000000000..9a06b7458d --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/lib.rs @@ -0,0 +1,284 @@ +//! Implementation of the AuthDecode protocol. +//! +//! The protocol performs authenticated decoding of encodings in zero knowledge. +//! +//! One of the use cases of AuthDecode is for the garbled circuits (GC) evaluator to produce a +//! zk-friendly hash commitment to either the GC input or the GC output, where computing such a +//! commitment directly using GC would be prohibitively expensive. +//! +//! The protocol consists of the following steps: +//! 1. The Prover commits to both the plaintext and the arithmetic sum of the active encodings of the +//! bits of the plaintext. (The protocol assumes that the Prover ascertained beforehand that the +//! active encodings are authentic.) +//! 2. The Prover obtains the full encodings of the plaintext bits from some outer context and uses +//! them to create a zk proof, proving that during Step 1. they knew the correct active encodings +//! of the plaintext and also proving that a hash commitment H is an authentic commitment to the +//! plaintext. +//! 3. The Verifier verifies the proof and accepts H as an authentic hash commitment to the plaintext. +//! +//! Important: when using the protocol, you must ensure that the Prover obtains the full encodings +//! from an outer context only **after** they've made a commitment in Step 1. + +pub mod backend; +pub mod encodings; +pub mod id; +pub mod msgs; +pub mod prover; +pub mod verifier; + +#[cfg(any(test, feature = "fixtures"))] +pub mod fixtures; +#[cfg(any(test, feature = "mock"))] +pub mod mock; + +pub use prover::Prover; +pub use verifier::Verifier; + +use serde::{Deserialize, Serialize}; + +/// The statistical security parameter used by the protocol. +pub const SSP: usize = 40; + +/// An opaque proof. +#[derive(Clone, Default, Serialize, Deserialize, Debug)] +pub struct Proof(Vec); +impl Proof { + /// Creates a new proof from bytes. + /// + /// # Arguments + /// + /// * `bytes` - The bytes from which to create the proof. + pub fn new(bytes: &[u8]) -> Self { + Self(bytes.to_vec()) + } +} + +/// Public inputs to the AuthDecode circuit. +#[derive(Clone, Default)] +pub struct PublicInput { + /// The hash commitment to the plaintext. + plaintext_hash: F, + /// The hash commitment to the sum of the encodings. + encoding_sum_hash: F, + /// The sum of the encodings which encode the value 0 of a bit . + zero_sum: F, + /// An arithmetic difference between the encoding of bit value 1 and encoding of bit value 0 for + /// each bit of the plaintext in MSB0 bit order. + deltas: Vec, +} + +#[cfg(test)] +mod tests { + use crate::{ + backend::traits::{Field, ProverBackend, VerifierBackend}, + fixtures, + mock::{MockBitIds, MockEncodingProvider}, + prover::{CommitmentData, ProofGenerated, Prover, ProverInput}, + verifier::{VerifiedSuccessfully, Verifier}, + Proof, + }; + + use rstest::*; + use serde::{de::DeserializeOwned, Serialize}; + use std::{ + any::Any, + cell::RefCell, + ops::{Add, Sub}, + }; + + #[fixture] + #[once] + fn commitment_data() -> Vec> { + fixtures::commitment_data() + } + + #[fixture] + #[once] + fn encoding_provider() -> MockEncodingProvider { + fixtures::encoding_provider() + } + + // Tests the protocol with a mock backend. 
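+ //
+ // Like every backend test below, it drives the full message flow of
+ // `run_authdecode`, sketched here with `?` in place of `unwrap()`:
+ //
+ //   let (prover, commit_msg) = prover.commit(data)?;
+ //   let verifier = verifier.receive_commitments(commit_msg)?;
+ //   let (prover, proofs_msg) = prover.prove(&provider)?;
+ //   let verifier = verifier.verify(proofs_msg, &provider)?;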
+ #[rstest] + fn test_mock_backend( + commitment_data: &[CommitmentData], + encoding_provider: &MockEncodingProvider, + ) { + run_authdecode( + crate::backend::mock::backend_pair(), + commitment_data, + encoding_provider, + ); + } + + // Tests the protocol with a mock halo2 prover and verifier. + #[rstest] + fn test_mock_halo2_backend( + commitment_data: &[CommitmentData], + encoding_provider: &MockEncodingProvider, + ) { + run_authdecode( + crate::backend::halo2::fixtures::backend_pair_mock(), + commitment_data, + encoding_provider, + ); + } + + // Tests the protocol with a halo2 prover and verifier. + #[ignore = "expensive"] + #[rstest] + fn test_halo2_backend( + commitment_data: &[CommitmentData], + encoding_provider: &MockEncodingProvider, + ) { + run_authdecode( + crate::backend::halo2::fixtures::backend_pair(), + commitment_data, + encoding_provider, + ); + } + + // Runs the protocol with the given backends. + // Returns the prover and the verifier in their finalized state. + #[allow(clippy::type_complexity)] + fn run_authdecode( + pair: ( + impl ProverBackend + 'static, + impl VerifierBackend + 'static, + ), + commitment_data: &[CommitmentData], + encoding_provider: &MockEncodingProvider, + ) -> ( + Prover, F>, + Verifier, F>, + ) + where + F: Field + Add + Sub + Serialize + DeserializeOwned + Clone, + { + let prover = Prover::new(Box::new(pair.0)); + let verifier = Verifier::new(Box::new(pair.1)); + + let (prover, commitments) = prover.commit(commitment_data.to_vec()).unwrap(); + + // Message types are checked during deserialization. + let commitments = bincode::serialize(&commitments).unwrap(); + let commitments = bincode::deserialize(&commitments).unwrap(); + + let verifier = verifier.receive_commitments(commitments).unwrap(); + + // An encoding provider is instantiated with authenticated full encodings from external context. + let (prover, proofs) = prover.prove(encoding_provider).unwrap(); + + // Message types are checked during deserialization. + let proofs = bincode::serialize(&proofs).unwrap(); + let proofs = bincode::deserialize(&proofs).unwrap(); + + let verifier = verifier.verify(proofs, encoding_provider).unwrap(); + + (prover, verifier) + } + + // Returns valid `ProverInput`s for the given backend pair which can be used as a fixture in + // backend tests. + pub fn proof_inputs_for_backend< + F: Field + Add + Sub + Serialize + DeserializeOwned + Clone + 'static, + >( + prover: impl ProverBackend + 'static, + verifier: impl VerifierBackend + 'static, + ) -> Vec> { + // Wrap the prover backend. + struct ProverBackendWrapper { + prover: Box>, + proof_inputs: RefCell>>>, + } + + impl ProverBackend for ProverBackendWrapper + where + F: Field + + Add + + Sub + + Serialize + + DeserializeOwned + + Clone + + 'static, + { + fn chunk_size(&self) -> usize { + self.prover.chunk_size() + } + + fn commit_encoding_sum(&self, encoding_sum: F) -> (F, F) { + self.prover.commit_encoding_sum(encoding_sum) + } + + fn commit_plaintext(&self, plaintext: Vec) -> (F, F) { + self.prover.commit_plaintext(plaintext) + } + + fn commit_plaintext_with_salt(&self, _plaintext: Vec, _salt: F) -> F { + unimplemented!() + } + + fn prove( + &self, + input: Vec>, + ) -> Result, crate::prover::ProverError> { + // Save proof inputs, return a dummy proof. + *self.proof_inputs.borrow_mut() = Some(input); + Ok(vec![Proof::new(&[0u8])]) + } + + fn as_any(&self) -> &dyn Any { + self + } + } + + // Wrap the verifier backend.
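+ // Like the prover wrapper above, it delegates bookkeeping to the real
+ // backend but stubs out verification, so a protocol run always completes
+ // and the captured proof inputs can be extracted afterwards.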
+ struct VerifierBackendWrapper { + verifier: Box>, + } + + impl VerifierBackend for VerifierBackendWrapper + where + F: Field + Add + Sub + Serialize + DeserializeOwned + Clone, + { + fn chunk_size(&self) -> usize { + self.verifier.chunk_size() + } + + fn verify( + &self, + _inputs: Vec>, + _proofs: Vec, + ) -> Result<(), crate::verifier::VerifierError> { + Ok(()) + } + } + + // Instantiate the backend pair. + let prover_wrapper = ProverBackendWrapper { + prover: Box::new(prover), + proof_inputs: RefCell::new(None), + }; + let verifier_wrapper = VerifierBackendWrapper { + verifier: Box::new(verifier), + }; + + // Run the protocol. + let (prover, _) = run_authdecode( + (prover_wrapper, verifier_wrapper), + &commitment_data(), + &encoding_provider(), + ); + + // Extract proof inputs from the backend. + prover + .backend() + .as_any() + .downcast_ref::>() + .unwrap() + .proof_inputs + .borrow() + .clone() + .unwrap() + } +} diff --git a/crates/components/authdecode/authdecode-core/src/mock.rs b/crates/components/authdecode/authdecode-core/src/mock.rs new file mode 100644 index 0000000000..480f5c57eb --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/mock.rs @@ -0,0 +1,240 @@ +use crate::{ + encodings::{Encoding, EncodingProvider, EncodingProviderError, FullEncodings}, + id::{Id, IdCollection}, +}; + +use itybity::{FromBitIterator, ToBits}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::{HashMap, VecDeque}, + marker::PhantomData, + ops::Range, +}; + +/// The direction of the transcript. +#[derive(Clone, PartialEq, Serialize, Deserialize, Default, Debug)] +pub enum Direction { + #[default] + Sent, + Received, +} + +/// A collection of ids of transcript bits. Each bit is uniquely identified by the transcript's direction +/// and the bit's index in the transcript. +/// Ranges may overlap. +#[derive(Clone, PartialEq, Serialize, Deserialize, Default, Debug)] +pub struct MockBitIds { + /// The direction of the transcript. + direction: Direction, + /// Ranges of bits in the transcript. The ranges may overlap. + ranges: VecDeque>, +} + +impl MockBitIds { + /// Constructs a new collection from ids in the given **byte** `ranges`. + pub fn new(direction: Direction, ranges: &[Range]) -> Self { + // Convert to bit ranges. + let ranges = ranges + .iter() + .map(|r| Range { + start: r.start * 8, + end: r.end * 8, + }) + .collect::>(); + Self { direction, ranges } + } + + /// Encodes the direction and the bit's `offset` in the transcript into an id. + /// + /// # Panics + /// + /// Panics if `offset` > 2^32. + fn encode_bit_id(&self, offset: usize) -> Id { + // All values are encoded in MSB-first order. + // The first bit encodes the direction, the remaining bits encode the offset. + let mut id = vec![false; 64]; + let encoded_direction = if self.direction == Direction::Sent { + [false] + } else { + [true] + }; + + assert!(offset < (1 << 32)); + + let encoded_offset = (offset as u32).to_be_bytes().to_msb0_vec(); + + id[0..1].copy_from_slice(&encoded_direction); + id[1 + (63 - encoded_offset.len())..].copy_from_slice(&encoded_offset); + + Id(u64::from_be_bytes( + boolvec_to_u8vec(&id).try_into().unwrap(), + )) + } + + /// Decodes bit id into the direction and the bit's offset in the transcript. 
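+ ///
+ /// A round-trip sketch (hypothetical offset):
+ ///
+ /// ```ignore
+ /// let ids = MockBitIds::new(Direction::Sent, &[0..1]);
+ /// let id = ids.encode_bit_id(5);
+ /// assert_eq!(ids.decode_bit_id(id), (Direction::Sent, 5));
+ /// ```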
+ #[allow(dead_code)] + fn decode_bit_id(&self, id: Id) -> (Direction, usize) { + let encoding = id.0.to_be_bytes().to_msb0_vec(); + let direction_encoding = &encoding[0..1]; + + let direction = if direction_encoding == [false] { + Direction::Sent + } else { + Direction::Received + }; + + let offset_encoding = &encoding[1..]; + let offset = usize::from_be_bytes(boolvec_to_u8vec(offset_encoding).try_into().unwrap()); + + (direction, offset) + } +} + +impl IdCollection for MockBitIds { + fn drain_front(&mut self, mut count: usize) -> Self { + let mut drained_ranges: VecDeque> = VecDeque::new(); + + while count > 0 { + let mut range = match self.ranges.remove(0) { + None => { + // Nothing more to drain. + break; + } + Some(range) => range, + }; + + // It is safe to `unwrap()` here and below since all iters/ranges will contain at least + // 1 element. + let min = range.clone().min().unwrap(); + let yielded = range.by_ref().take(count); + let max = yielded.max().unwrap() + 1; + drained_ranges.push_back(Range { + start: min, + end: max, + }); + + // If the range was only partially drained, put back the undrained subrange. + if !range.is_empty() { + self.ranges.push_back(Range { + start: range.clone().min().unwrap(), + end: range.max().unwrap() + 1, + }); + break; + } + + count -= max - min; + } + + Self { + direction: self.direction.clone(), + // Optimization: combine adjacent ranges. + ranges: drained_ranges, + } + } + + fn id(&self, index: usize) -> Id { + assert!(self.len() > index); + // How many indices already checked. + let mut checked = 0; + + // Find which range the `index` is located in. + for r in &self.ranges { + if checked + r.len() > index { + // Offset of the `index` from the start of this range. + let offset = index - checked; + return self.encode_bit_id(r.start + offset); + } + checked += r.len(); + } + + unreachable!() + } + + fn len(&self) -> usize { + self.ranges.iter().map(|r| r.len()).sum() + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn new_from_iter>(iter: I) -> Self { + let mut direction = None; + let ranges = iter + .into_iter() + .flat_map(|i| { + if let Some(dir) = &direction { + assert!(dir == &i.direction) + } else { + // On first iteration, set the direction. + direction = Some(i.direction) + } + i.ranges + }) + .collect::>(); + + Self { + direction: direction.unwrap(), + // Optimization: combine adjacent ranges. + ranges, + } + } +} + +/// A mock provider of encodings. +#[derive(Clone)] +pub struct MockEncodingProvider +where + T: IdCollection, +{ + /// A mapping from a bit id to the full encoding of the bit. 
+ full_encodings: HashMap, + phantom: PhantomData, +} + +impl MockEncodingProvider +where + T: IdCollection, +{ + pub fn new(full_encodings: FullEncodings) -> Self { + let mut hashmap = HashMap::new(); + let ids = (0..full_encodings.ids().len()) + .map(|idx| full_encodings.ids().id(idx)) + .collect::>(); + + for (full_enc, id) in full_encodings.encodings().iter().zip(ids) { + if hashmap.insert(id.clone(), *full_enc).is_some() { + panic!("duplicate ids detected"); + } + } + Self { + full_encodings: hashmap, + phantom: PhantomData, + } + } +} + +impl EncodingProvider for MockEncodingProvider +where + T: IdCollection, +{ + fn get_by_ids(&self, ids: &T) -> Result, EncodingProviderError> { + let all_ids = (0..ids.len()).map(|idx| ids.id(idx)).collect::>(); + + let full_encodings = all_ids + .iter() + .map(|id| *self.full_encodings.get(id).unwrap()) + .collect::>(); + Ok(FullEncodings::new(full_encodings, ids.clone())) + } +} + +/// Converts bits in MSB-first order into BE bytes. The bits will be internally left-padded +/// with zeroes to the nearest multiple of 8. +fn boolvec_to_u8vec(bv: &[bool]) -> Vec { + // Reverse to lsb0 since `itybity` can only pad the rightmost bits. + let mut b = Vec::::from_lsb0_iter(bv.iter().rev().copied()); + // Reverse to get big endian byte order. + b.reverse(); + b +} diff --git a/crates/components/authdecode/authdecode-core/src/msgs.rs b/crates/components/authdecode/authdecode-core/src/msgs.rs new file mode 100644 index 0000000000..8bf0b9fd13 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/msgs.rs @@ -0,0 +1,228 @@ +//! Protocol messages and types contained therein. + +use crate::{ + backend::traits::Field, + id::IdCollection, + prover::CommitmentDetails, + verifier::{UnverifiedChunkCommitment, UnverifiedCommitment}, + Proof, +}; + +use enum_try_as_inner::EnumTryAsInner; +use serde::{Deserialize, Serialize}; + +/// A protocol message. +#[derive(Debug, Clone, Serialize, EnumTryAsInner, Deserialize)] +#[derive_err(Debug)] +#[allow(missing_docs)] +pub enum Message { + Commit(Commit), + Proofs(Proofs), +} + +impl From> for std::io::Error +where + I: IdCollection, + F: Field, +{ + fn from(err: MessageError) -> Self { + std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string()) + } +} + +/// A commitment message sent by the prover. +#[derive(Clone, Serialize, Deserialize, Debug)] +#[serde(try_from = "UncheckedCommit")] +pub struct Commit +where + I: IdCollection, + F: Field, +{ + /// A non-empty collection of commitments. Each element is a commitment to plaintext of an + /// arbitrary length. + commitments: Vec>, +} + +impl Commit +where + I: IdCollection, + F: Field, +{ + /// Returns the total number of chunks across all commitments in the collection. + pub fn chunk_count(&self) -> usize { + self.commitments + .iter() + .map(|inner| inner.chunk_commitments.len()) + .sum() + } + + /// Returns the total number of commitments in the collection. + pub fn commitment_count(&self) -> usize { + self.commitments.len() + } +} + +/// A message with proofs sent by the prover. +#[derive(Clone, Serialize, Deserialize, Debug)] +#[serde(try_from = "UncheckedProofs")] +pub struct Proofs { + pub proofs: Vec, +} + +impl Commit +where + I: IdCollection, + F: Field, +{ + /// Converts this message into a collection of unverified commitments which the verifier can + /// work with. + /// + /// # Arguments + /// * `max_size` - The expected maximum bytesize of a chunk of plaintext committed to. 
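+ ///
+ /// Returns an error if any chunk commitment covers more than `max_size * 8`
+ /// bit ids.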
+ pub fn into_vec_commitment( + self, + max_size: usize, + ) -> Result>, std::io::Error> { + self.commitments + .into_iter() + .map(|com| { + let chunk_com = com + .chunk_commitments + .into_iter() + .map(|chunk_com| { + if chunk_com.ids.len() > max_size * 8 { + Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "The length of ids is larger than the chunk size.", + )) + } else { + Ok(UnverifiedChunkCommitment::new( + chunk_com.plaintext_hash, + chunk_com.encoding_sum_hash, + chunk_com.ids, + )) + } + }) + .collect::, std::io::Error>>()?; + + Ok(UnverifiedCommitment::new(chunk_com)) + }) + .collect::, std::io::Error>>() + } +} + +impl From>> for Commit +where + I: IdCollection, + F: Field + Clone, +{ + fn from(source: Vec>) -> Commit { + Commit { + commitments: source + .into_iter() + .map(|com| { + let chunk_commitments = com + .chunk_commitments() + .iter() + .map(|chunk_com| ChunkCommitment { + plaintext_hash: chunk_com.plaintext_hash().clone(), + encoding_sum_hash: chunk_com.encoding_sum_hash().clone(), + ids: chunk_com.ids().clone(), + }) + .collect::>(); + Commitment { chunk_commitments } + }) + .collect::>(), + } + } +} + +/// A single commitment to plaintext of an arbitrary length. +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Commitment +where + I: IdCollection, + F: Field, +{ + /// A non-empty collection of commitments to each chunk of the plaintext. + chunk_commitments: Vec>, +} + +/// A commitment to a single chunk of plaintext. +#[derive(Clone, Serialize, Deserialize, Debug)] +struct ChunkCommitment +where + I: IdCollection, + F: Field, +{ + /// Hash commitment to the plaintext. + plaintext_hash: F, + /// Hash commitment to the `encoding_sum`. + encoding_sum_hash: F, + /// The id of each bit of the plaintext. + ids: I, +} + +/// A [`Commit`] message in its unchecked state. +#[derive(Deserialize)] +pub struct UncheckedCommit +where + I: IdCollection, + F: Field, +{ + commitments: Vec>, +} + +impl TryFrom> for Commit +where + I: IdCollection, + F: Field, +{ + type Error = std::io::Error; + + fn try_from(value: UncheckedCommit) -> Result { + // None of the commitment vectors should be empty. + if value.commitments.is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "empty commitment vector", + )); + } + + for com in &value.commitments { + if com.chunk_commitments.is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "empty chunk commitment vector", + )); + } + } + + Ok(Commit { + commitments: value.commitments, + }) + } +} + +#[derive(Deserialize)] +/// A [`Proof`] message in its unchecked state. 
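+ /// Conversion into [`Proofs`] rejects an empty proof vector.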
+pub struct UncheckedProofs { + proofs: Vec, +} + +impl TryFrom for Proofs { + type Error = std::io::Error; + + fn try_from(value: UncheckedProofs) -> Result { + if value.proofs.is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "empty proof vector", + )); + } + + Ok(Proofs { + proofs: value.proofs, + }) + } +} diff --git a/crates/components/authdecode/authdecode-core/src/prover.rs b/crates/components/authdecode/authdecode-core/src/prover.rs new file mode 100644 index 0000000000..f63a9afe66 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/prover.rs @@ -0,0 +1,224 @@ +use crate::{ + backend::traits::{Field, ProverBackend as Backend}, + encodings::EncodingProvider, + id::IdCollection, + msgs::{Commit, Proofs}, + PublicInput, +}; + +use getset::Getters; +use std::{marker::PhantomData, ops::Add}; + +#[cfg(feature = "tracing")] +use tracing::{debug, debug_span, instrument, Instrument}; + +mod commitment; +mod error; +mod state; + +pub use commitment::{CommitmentData, CommitmentDetails}; +pub use error::ProverError; +pub use state::{Committed, Initialized, ProofGenerated, ProverState}; + +/// The prover's public and private inputs to the circuit. +#[derive(Clone, Default, Getters)] +pub struct ProverInput { + /// The public input. + #[getset(get = "pub")] + public: PublicInput, + /// The private input. + #[getset(get = "pub")] + private: PrivateInput, +} + +/// Private inputs to the AuthDecode circuit. +#[derive(Clone, Default, Getters)] +pub struct PrivateInput { + /// The plaintext committed to. + #[getset(get = "pub")] + plaintext: Vec, + /// The salt used to create the commitment to the plaintext. + #[getset(get = "pub")] + plaintext_salt: F, + /// The salt used to create the commitment to the sum of the encodings. + #[getset(get = "pub")] + encoding_sum_salt: F, +} + +/// Prover in the AuthDecode protocol. +pub struct Prover { + /// The zk backend. + backend: Box>, + /// The current state of the prover. + state: S, + pd: PhantomData, +} + +impl Prover +where + I: IdCollection, + F: Field + Add, +{ + /// Creates a new prover. + /// + /// # Arguments + /// + /// * `backend` - The zk backend. + pub fn new(backend: Box>) -> Self { + Prover { + backend, + state: state::Initialized {}, + pd: PhantomData, + } + } + + /// Creates a commitment to each element in the `data_set`. + /// + /// Returns the prover in a new state and the message to be passed to the verifier. + /// + /// # Arguments + /// + /// * `data_set` - The set of commitment data to be committed to. + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] + #[allow(clippy::type_complexity)] + pub fn commit( + self, + data_set: Vec>, + ) -> Result<(Prover, F>, Commit), ProverError> + where + I: IdCollection, + F: Field + Clone + std::ops::Add, + { + // Commit to each commitment data in the set individually. + let commitments = data_set + .into_iter() + .map(|data| data.commit(&self.backend)) + .collect::>, ProverError>>()?; + + Ok(( + Prover { + backend: self.backend, + state: Committed { + commitments: commitments.clone(), + }, + pd: PhantomData, + }, + commitments.into(), + )) + } + + /// Creates a commitment to each element in the `data_set` with the provided salts. + /// + /// Returns the prover in a new state and the message to be passed to the verifier. + /// + /// # Arguments + /// + /// * `data_set` - The set of commitment data with salts for each chunk of it. 
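+ ///
+ /// Each chunk of a commitment data element is paired with the salt at the
+ /// same position; supplying more salts than there are chunks is an error.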
+ #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] + #[allow(clippy::type_complexity)] + pub fn commit_with_salt( + self, + data_set: Vec<(CommitmentData, Vec)>, + ) -> Result<(Prover, F>, Commit), ProverError> + where + I: IdCollection, + F: Field + Clone + std::ops::Add, + { + // Commit to each element in the set individually. + let commitments = data_set + .into_iter() + .map(|(data, salt)| data.commit_with_salt(&self.backend, salt)) + .collect::>, ProverError>>()?; + + Ok(( + Prover { + backend: self.backend, + state: Committed { + commitments: commitments.clone(), + }, + pd: PhantomData, + }, + commitments.into(), + )) + } +} + +impl Prover, F> +where + I: IdCollection, + F: Field + Clone + std::ops::Sub + std::ops::Add, +{ + /// Generates zk proofs. + /// + /// Returns the prover in a new state and the message to be passed to the verifier. + /// + /// # Arguments + /// + /// * `encoding_provider` - The provider of full encodings for the plaintext committed to + /// earlier. + #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] + #[allow(clippy::type_complexity)] + pub fn prove( + self, + encoding_provider: &impl EncodingProvider, + ) -> Result<(Prover, F>, Proofs), ProverError> { + // Collect proof inputs for each chunk of plaintext committed to. + let proof_inputs = self + .state + .commitments + .clone() + .into_iter() + .flat_map(|com| { + let coms = com + .chunk_commitments() + .iter() + .map(|com| { + let full_encodings = encoding_provider.get_by_ids(com.ids())?; + + Ok(ProverInput { + public: PublicInput { + deltas: full_encodings.compute_deltas::(), + plaintext_hash: com.plaintext_hash().clone(), + encoding_sum_hash: com.encoding_sum_hash().clone(), + zero_sum: full_encodings.compute_zero_sum(), + }, + private: PrivateInput { + plaintext: com.encodings().plaintext(), + plaintext_salt: com.plaintext_salt().clone(), + encoding_sum_salt: com.encoding_sum_salt().clone(), + }, + }) + }) + .collect::, ProverError>>()?; + + Ok::>, ProverError>(coms) + }) + .flatten() + .collect::>(); + + let proofs = self.backend.prove(proof_inputs)?; + + Ok(( + Prover { + backend: self.backend, + state: ProofGenerated { + commitments: self.state.commitments, + }, + pd: PhantomData, + }, + Proofs { proofs }, + )) + } +} + +#[cfg(any(test, feature = "fixtures"))] +impl Prover, F> +where + I: IdCollection, + F: Field + Clone + std::ops::Sub + std::ops::Add, +{ + // Testing only. Returns the backend that can be downcast to a concrete type. + pub fn backend(self) -> Box> { + self.backend + } +} diff --git a/crates/components/authdecode/authdecode-core/src/prover/commitment.rs b/crates/components/authdecode/authdecode-core/src/prover/commitment.rs new file mode 100644 index 0000000000..936b6c4943 --- /dev/null +++ b/crates/components/authdecode/authdecode-core/src/prover/commitment.rs @@ -0,0 +1,252 @@ +use crate::{ + backend::traits::{Field, ProverBackend as Backend}, + encodings::{active::ActiveEncodingsChunks, ActiveEncodings, Encoding}, + id::IdCollection, + prover::error::ProverError, + SSP, +}; + +use getset::Getters; +use itybity::ToBits; + +/// The plaintext and the encodings which the prover commits to. +#[derive(Clone, Default)] +pub struct CommitmentData +where + I: IdCollection, +{ + encodings: ActiveEncodings, +} + +impl CommitmentData +where + I: IdCollection, +{ + /// Creates a commitment to this commitment data. 
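+ ///
+ /// The data is split into chunks of the backend's chunk size and each chunk
+ /// is committed to individually.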
+ #[allow(clippy::borrowed_box)] + pub fn commit( + self, + backend: &Box>, + ) -> Result, ProverError> + where + F: Field + Clone + std::ops::Add, + { + // Chunk up the data and commit to each chunk individually. + let chunk_commitments = self + .into_chunks(backend.chunk_size()) + .map(|data_chunk| data_chunk.commit(backend)) + .collect::>>(); + + Ok(CommitmentDetails { chunk_commitments }) + } + + /// Creates a commitment to this commitment data with the provided plaintext salt. + /// + /// Returns an error if the amount of salts is not equal to the amount of chunks. + #[allow(clippy::borrowed_box)] + pub fn commit_with_salt( + self, + backend: &Box>, + salts: Vec, + ) -> Result, ProverError> + where + F: Field + Clone + std::ops::Add, + { + // Chunk up the data. + let chunks = self.into_chunks(backend.chunk_size()).collect::>(); + + if chunks.len() < salts.len() { + return Err(ProverError::MismatchedSaltChunkCount); + } + + let chunk_commitments = chunks + .into_iter() + .zip(salts) + .map(|(chunk, salt)| chunk.commit_with_salt(backend, salt)) + .collect::>>(); + + Ok(CommitmentDetails { chunk_commitments }) + } + + /// Creates new commitment data. + /// + /// # Arguments + /// * `plaintext` - The plaintext being committed to. + /// * `encodings` - Uniformly random encodings of every bit of the `plaintext` in MSB0 bit order. + /// Note that correlated encodings like those used in garbled circuits must + /// not be used since they are not uniformly random. + /// * `bit_ids` - The id of each bit of the `plaintext`. + /// + /// # Panics + /// + /// Panics if `plaintext`, `encodings` and `bit_ids` are not all of the same length. + pub fn new(plaintext: &[u8], encodings: &[[u8; SSP / 8]], bit_ids: I) -> CommitmentData { + assert!(plaintext.len() * 8 == encodings.len()); + assert!(encodings.len() == bit_ids.len()); + + let encodings = plaintext + .to_msb0_vec() + .into_iter() + .zip(encodings) + .map(|(bit, enc)| Encoding::new(*enc, bit)) + .collect::>(); + + CommitmentData { + encodings: ActiveEncodings::new(encodings, bit_ids), + } + } + + /// Convert `self` into an iterator over chunks of the commitment data. If `chunk_size` does not + /// divide the length of the commitment data, then the last chunk will not have length `chunk_size`. + /// + /// # Arguments + /// + /// * `chunk_size` - The size of a chunk. + pub fn into_chunks(self, chunk_size: usize) -> CommitmentDataChunks { + CommitmentDataChunks { + encodings: self.encodings.clone().into_chunks(chunk_size * 8), + } + } +} + +pub struct CommitmentDataChunks { + encodings: ActiveEncodingsChunks, +} + +impl Iterator for CommitmentDataChunks +where + I: IdCollection, +{ + type Item = CommitmentDataChunk; + + fn next(&mut self) -> Option { + self.encodings + .next() + .map(|encodings| Some(CommitmentDataChunk { encodings }))? + } +} + +/// A chunk of data that needs to be committed to. +pub struct CommitmentDataChunk +where + I: IdCollection, +{ + /// The active encoding of each bit of the plaintext. The number of encodings is always a + /// multiple of 8. + encodings: ActiveEncodings, +} + +impl CommitmentDataChunk +where + I: IdCollection, +{ + /// Creates a commitment to this chunk. 
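+ ///
+ /// Produces both the plaintext commitment and the encoding sum commitment,
+ /// together with their respective salts.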
+ #[allow(clippy::borrowed_box)] + fn commit(&self, backend: &Box>) -> ChunkCommitmentDetails + where + F: Field + Clone + std::ops::Add, + { + let sum = self.encodings.compute_sum::(); + + let (plaintext_hash, plaintext_salt) = backend.commit_plaintext(self.encodings.plaintext()); + + let (encoding_sum_hash, encoding_sum_salt) = backend.commit_encoding_sum(sum.clone()); + + ChunkCommitmentDetails { + plaintext_hash, + plaintext_salt, + encodings: self.encodings.clone(), + encoding_sum: sum, + encoding_sum_hash, + encoding_sum_salt, + } + } + + /// Creates a commitment to this chunk with the provided salt. + #[allow(clippy::borrowed_box)] + fn commit_with_salt( + &self, + backend: &Box>, + salt: F, + ) -> ChunkCommitmentDetails + where + F: Field + Clone + std::ops::Add, + { + let sum = self.encodings.compute_sum::(); + + let plaintext_hash = + backend.commit_plaintext_with_salt(self.encodings.plaintext(), salt.clone()); + + let (encoding_sum_hash, encoding_sum_salt) = backend.commit_encoding_sum(sum.clone()); + + ChunkCommitmentDetails { + plaintext_hash, + plaintext_salt: salt, + encodings: self.encodings.clone(), + encoding_sum: sum, + encoding_sum_hash, + encoding_sum_salt, + } + } +} + +/// An AuthDecode commitment to a single chunk of plaintext with the associated details. +#[derive(Clone, Getters)] +pub struct ChunkCommitmentDetails { + /// Hash commitment to the plaintext. + #[getset(get = "pub")] + plaintext_hash: F, + /// The salt used to create the commitment to the plaintext. + #[getset(get = "pub")] + plaintext_salt: F, + /// The encodings the sum of which is committed to. + #[getset(get = "pub")] + encodings: ActiveEncodings, + /// The sum of the encodings. + #[getset(get = "pub")] + encoding_sum: F, + /// Hash commitment to the `encoding_sum`. + #[getset(get = "pub")] + encoding_sum_hash: F, + /// The salt used to create the commitment to the `encoding_sum`. + #[getset(get = "pub")] + encoding_sum_salt: F, +} + +impl ChunkCommitmentDetails +where + I: IdCollection, + F: Field, +{ + /// Returns the id of each bit of the plaintext. + pub fn ids(&self) -> &I { + self.encodings.ids() + } +} + +/// An AuthDecode commitment to plaintext of arbitrary length with the associated details. +#[derive(Clone, Default, Getters)] +pub struct CommitmentDetails { + /// Commitments to each chunk of the plaintext with the associated details. + /// + /// Internally, for performance reasons, the data to be committed to is split up into chunks + /// and each chunk is committed to separately. The collection of chunk commitments constitutes + /// the commitment. + #[getset(get = "pub")] + chunk_commitments: Vec>, +} + +impl CommitmentDetails +where + I: IdCollection + Clone, + F: Field + Clone, +{ + /// Returns the encodings of the plaintext of this commitment. 
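+ ///
+ /// Concatenates the active encodings of all chunk commitments in order.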
+impl<I, F> CommitmentDetails<I, F>
+where
+    I: IdCollection + Clone,
+    F: Field + Clone,
+{
+    /// Returns the encodings of the plaintext of this commitment.
+    pub fn encodings(&self) -> ActiveEncodings<I> {
+        let iter = self
+            .chunk_commitments
+            .iter()
+            .map(|enc| enc.encodings.clone());
+        ActiveEncodings::new_from_iter(iter)
+    }
+}
diff --git a/crates/components/authdecode/authdecode-core/src/prover/error.rs b/crates/components/authdecode/authdecode-core/src/prover/error.rs
new file mode 100644
index 0000000000..6f3a9ea41b
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/src/prover/error.rs
@@ -0,0 +1,13 @@
+#[derive(Debug, thiserror::Error)]
+pub enum ProverError {
+    #[error(transparent)]
+    IOError(#[from] std::io::Error),
+    #[error("The proof system returned an error when generating a proof: {0}")]
+    ProvingBackendError(String),
+    #[error(transparent)]
+    EncodingProviderError(#[from] crate::encodings::EncodingProviderError),
+    #[error("The count of salts does not match the count of commitment data")]
+    MismatchedSaltCommitmentDataCount,
+    #[error("The count of salts does not match the count of chunks")]
+    MismatchedSaltChunkCount,
+}
diff --git a/crates/components/authdecode/authdecode-core/src/prover/state.rs b/crates/components/authdecode/authdecode-core/src/prover/state.rs
new file mode 100644
index 0000000000..1c123f62b1
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/src/prover/state.rs
@@ -0,0 +1,54 @@
+//! AuthDecode prover states.
+
+use crate::{backend::traits::Field, id::IdCollection, prover::commitment::CommitmentDetails};
+
+/// The state of the Prover throughout the AuthDecode protocol.
+pub trait ProverState: sealed::Sealed {}
+
+/// The initial state.
+pub struct Initialized {}
+opaque_debug::implement!(Initialized);
+
+/// The state after the prover has made a commitment.
+pub struct Committed<I, F> {
+    pub commitments: Vec<CommitmentDetails<I, F>>,
+}
+opaque_debug::implement!(Committed<I, F>);
+
+/// The state after the prover generated proofs.
+pub struct ProofGenerated<I, F> {
+    pub commitments: Vec<CommitmentDetails<I, F>>,
+}
+opaque_debug::implement!(ProofGenerated<I, F>);
+
+impl ProverState for Initialized {}
+impl<I, F> ProverState for Committed<I, F>
+where
+    I: IdCollection,
+    F: Field + Clone,
+{
+}
+impl<I, F> ProverState for ProofGenerated<I, F>
+where
+    I: IdCollection,
+    F: Field + Clone,
+{
+}
+
+mod sealed {
+    use crate::{id::IdCollection, prover::state::Field};
+    pub trait Sealed {}
+    impl Sealed for super::Initialized {}
+    impl<I, F> Sealed for super::Committed<I, F>
+    where
+        I: IdCollection,
+        F: Field + Clone,
+    {
+    }
+    impl<I, F> Sealed for super::ProofGenerated<I, F>
+    where
+        I: IdCollection,
+        F: Field + Clone,
+    {
+    }
+}
diff --git a/crates/components/authdecode/authdecode-core/src/verifier.rs b/crates/components/authdecode/authdecode-core/src/verifier.rs
new file mode 100644
index 0000000000..6242fc2a8d
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/src/verifier.rs
@@ -0,0 +1,143 @@
+use std::marker::PhantomData;
+
+use crate::{
+    backend::traits::{Field, VerifierBackend as Backend},
+    encodings::EncodingProvider,
+    id::IdCollection,
+    msgs::{Commit, Proofs},
+    PublicInput,
+};
+
+#[cfg(feature = "tracing")]
+use tracing::{debug, debug_span, instrument, Instrument};
+
+mod commitment;
+mod error;
+mod state;
+
+pub use commitment::VerifiedCommitment;
+pub(crate) use commitment::{UnverifiedChunkCommitment, UnverifiedCommitment};
+pub use error::VerifierError;
+pub use state::{CommitmentReceived, Initialized, VerifiedSuccessfully, VerifierState};
+
+/// Verifier in the AuthDecode protocol.
+pub struct Verifier<I, S, F>
+where
+    I: IdCollection,
+    F: Field,
+    S: state::VerifierState,
+{
+    /// The backend for zk proof verification.
+    backend: Box<dyn Backend<F>>,
+    /// The state of the verifier.
+    state: S,
+    phantom: PhantomData<I>,
+}
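[Editor's note] Both the prover and verifier here are typestate machines: each protocol step consumes `self` and returns the machine parameterized by the next state, so out-of-order calls are compile-time errors. A minimal, crate-independent sketch of the pattern:

```rust
struct Machine<S> {
    state: S,
}

struct Start;
struct Ready {
    data: Vec<u8>,
}

impl Machine<Start> {
    // Consuming `self` and returning a differently-typed machine advances the state.
    fn receive(self, data: Vec<u8>) -> Machine<Ready> {
        Machine {
            state: Ready { data },
        }
    }
}

impl Machine<Ready> {
    fn process(self) -> usize {
        self.state.data.len()
    }
}

fn main() {
    let m = Machine { state: Start };
    // `m.process()` would not compile here; only `receive` is available.
    let n = m.receive(vec![1, 2, 3]).process();
    assert_eq!(n, 3);
}
```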
+
+impl<I, F> Verifier<I, state::Initialized, F>
+where
+    I: IdCollection,
+    F: Field,
+{
+    /// Creates a new verifier.
+    ///
+    /// # Arguments
+    ///
+    /// * `backend` - The backend for zk proof verification.
+    pub fn new(backend: Box<dyn Backend<F>>) -> Self {
+        Verifier {
+            backend,
+            state: state::Initialized {},
+            phantom: PhantomData,
+        }
+    }
+
+    /// Receives the commitments and stores them.
+    ///
+    /// Returns the verifier in a new state.
+    ///
+    /// # Arguments
+    ///
+    /// * `commitments` - The prover's message containing commitments.
+    #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))]
+    pub fn receive_commitments(
+        self,
+        commitments: Commit<I, F>,
+    ) -> Result<Verifier<I, state::CommitmentReceived<I, F>, F>, VerifierError> {
+        let commitments: Vec<UnverifiedCommitment<I, F>> =
+            commitments.into_vec_commitment(self.backend.chunk_size())?;
+
+        Ok(Verifier {
+            backend: self.backend,
+            state: state::CommitmentReceived { commitments },
+            phantom: PhantomData,
+        })
+    }
+}
+
+impl<I, F> Verifier<I, state::CommitmentReceived<I, F>, F>
+where
+    I: IdCollection,
+    F: Field + std::ops::Add<Output = F> + std::ops::Sub<Output = F> + Clone,
+{
+    /// Verifies proofs for the commitments received earlier.
+    ///
+    /// Returns the verifier in a new state.
+    ///
+    /// # Arguments
+    /// * `proofs` - The prover's message containing proofs.
+    /// * `encoding_provider` - The provider of the full encodings for the plaintext bits.
+    #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))]
+    pub fn verify(
+        self,
+        proofs: Proofs,
+        encoding_provider: &(impl EncodingProvider<I> + 'static),
+    ) -> Result<Verifier<I, state::VerifiedSuccessfully<I, F>, F>, VerifierError> {
+        let Proofs { proofs } = proofs;
+
+        // Compute public inputs to verify each chunk of plaintext committed to.
+        let public_inputs = self
+            .state
+            .commitments
+            .iter()
+            .flat_map(|com| com.chunk_commitments())
+            .map(|com| {
+                let encodings = encoding_provider.get_by_ids(com.ids())?;
+
+                Ok(PublicInput {
+                    plaintext_hash: com.plaintext_hash().clone(),
+                    encoding_sum_hash: com.encoding_sum_hash().clone(),
+                    zero_sum: encodings.compute_zero_sum(),
+                    deltas: encodings.compute_deltas(),
+                })
+            })
+            .collect::<Result<Vec<_>, VerifierError>>()?;
+
+        self.backend.verify(public_inputs, proofs)?;
+
+        Ok(Verifier {
+            backend: self.backend,
+            state: state::VerifiedSuccessfully {
+                commitments: self
+                    .state
+                    .commitments
+                    .into_iter()
+                    .map(|com| com.into())
+                    .collect(),
+            },
+            phantom: PhantomData,
+        })
+    }
+}
+
+impl<I, F> Verifier<I, state::VerifiedSuccessfully<I, F>, F>
+where
+    I: IdCollection,
+    F: Field + std::ops::Add<Output = F> + std::ops::Sub<Output = F> + Clone,
+{
+    /// Returns the verified commitments.
+    pub fn commitments(&self) -> &Vec<VerifiedCommitment<I, F>> {
+        &self.state.commitments
+    }
+}
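[Editor's note] A crate-independent sketch, in the clear, of the statement the zk backend proves for each chunk: the committed plaintext opens `plaintext_hash`, and the committed encoding sum equals `zero_sum` plus the deltas selected by the plaintext bits. std's `DefaultHasher` stands in for the zk-friendly hash (Poseidon); all names and values are illustrative.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn hash_with_salt(data: &[u8], salt: u64) -> u64 {
    let mut h = DefaultHasher::new();
    data.hash(&mut h);
    salt.hash(&mut h);
    h.finish()
}

fn check_statement(
    bits: &[bool],
    plaintext_salt: u64,
    sum_salt: u64,
    plaintext_hash: u64,
    encoding_sum_hash: u64,
    zero_sum: u128,
    deltas: &[u128],
) -> bool {
    let bytes: Vec<u8> = bits.iter().map(|&b| b as u8).collect();
    // The encoding sum implied by the claimed plaintext bits.
    let sum = zero_sum
        + bits
            .iter()
            .zip(deltas)
            .map(|(&b, d)| if b { *d } else { 0 })
            .sum::<u128>();

    hash_with_salt(&bytes, plaintext_salt) == plaintext_hash
        && hash_with_salt(&sum.to_be_bytes(), sum_salt) == encoding_sum_hash
}

fn main() {
    let (bits, deltas) = ([true, false], [7u128, 9]);
    let zero_sum = 12u128;
    let sum = zero_sum + 7;

    let pt_hash = hash_with_salt(&[1, 0], 42);
    let sum_hash = hash_with_salt(&sum.to_be_bytes(), 43);
    assert!(check_statement(&bits, 42, 43, pt_hash, sum_hash, zero_sum, &deltas));
}
```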
diff --git a/crates/components/authdecode/authdecode-core/src/verifier/commitment.rs b/crates/components/authdecode/authdecode-core/src/verifier/commitment.rs
new file mode 100644
index 0000000000..68aeac1014
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/src/verifier/commitment.rs
@@ -0,0 +1,180 @@
+use crate::{backend::traits::Field, id::IdCollection};
+
+use getset::Getters;
+
+/// A yet-unverified commitment to plaintext of an arbitrary length and related details.
+#[derive(Clone, Getters)]
+pub struct UnverifiedCommitment<I, F> {
+    /// A non-empty collection of commitment details for each chunk of the plaintext.
+    #[getset(get = "pub")]
+    chunk_commitments: Vec<UnverifiedChunkCommitment<I, F>>,
+}
+
+impl<I, F> UnverifiedCommitment<I, F>
+where
+    I: IdCollection + Default,
+    F: Field,
+{
+    /// Creates a new `UnverifiedCommitment` instance.
+    ///
+    /// # Arguments
+    ///
+    /// * `chunk_commitments` - A non-empty collection of commitment details for each chunk of
+    ///                         the plaintext.
+    pub fn new(chunk_commitments: Vec<UnverifiedChunkCommitment<I, F>>) -> Self {
+        Self { chunk_commitments }
+    }
+
+    /// Returns the id of each bit of the plaintext of this commitment.
+    pub fn ids(&self) -> I {
+        let iter = self
+            .chunk_commitments
+            .iter()
+            .map(|com| com.ids.clone())
+            .collect::<Vec<_>>();
+
+        I::new_from_iter(iter)
+    }
+
+    /// Returns the length of the plaintext of this commitment.
+    #[allow(clippy::len_without_is_empty)]
+    pub fn len(&self) -> usize {
+        self.chunk_commitments.iter().map(|com| com.ids.len()).sum()
+    }
+}
+
+/// Yet-unverified commitment details for a single chunk of plaintext.
+#[derive(Clone, Getters)]
+pub struct UnverifiedChunkCommitment<I, F> {
+    /// Hash commitment to the plaintext.
+    #[getset(get = "pub")]
+    plaintext_hash: F,
+    /// Hash commitment to the arithmetic sum of the encodings of the plaintext.
+    #[getset(get = "pub")]
+    encoding_sum_hash: F,
+    /// The id of each bit of the committed plaintext in MSB0 bit order.
+    #[getset(get = "pub")]
+    ids: I,
+}
+
+impl<I, F> UnverifiedChunkCommitment<I, F>
+where
+    I: IdCollection,
+    F: Field,
+{
+    /// Creates a new unverified chunk commitment.
+    ///
+    /// # Arguments
+    ///
+    /// * `plaintext_hash` - Hash commitment to the plaintext.
+    /// * `encoding_sum_hash` - Hash commitment to the arithmetic sum of the encodings of the
+    ///                         plaintext.
+    /// * `ids` - The id of each bit of the committed plaintext.
+    pub fn new(plaintext_hash: F, encoding_sum_hash: F, ids: I) -> Self {
+        Self {
+            plaintext_hash,
+            encoding_sum_hash,
+            ids,
+        }
+    }
+
+    /// Returns the bitlength of the plaintext committed to.
+    #[allow(clippy::len_without_is_empty)]
+    pub fn len(&self) -> usize {
+        self.ids.len()
+    }
+}
+
+/// A verified commitment to plaintext of an arbitrary length.
+#[derive(Clone)]
+pub struct VerifiedCommitment<I, F> {
+    /// A non-empty collection of commitments for each chunk of the plaintext.
+    chunk_commitments: Vec<VerifiedChunkCommitment<I, F>>,
+}
+
+impl<I, F> VerifiedCommitment<I, F>
+where
+    I: IdCollection + Default,
+    F: Field,
+{
+    /// Creates a new instance.
+    ///
+    /// # Arguments
+    ///
+    /// * `chunk_commitments` - A non-empty collection of commitment details for each chunk of
+    ///                         the plaintext.
+    pub fn new(chunk_commitments: Vec<VerifiedChunkCommitment<I, F>>) -> Self {
+        Self { chunk_commitments }
+    }
+
+    /// Returns the length of the plaintext of this commitment.
+    #[allow(clippy::len_without_is_empty)]
+    pub fn len(&self) -> usize {
+        self.chunk_commitments.iter().map(|com| com.ids.len()).sum()
+    }
+
+    /// Returns a non-empty collection of commitments for each chunk of the plaintext.
+    pub fn chunk_commitments(&self) -> &Vec<VerifiedChunkCommitment<I, F>> {
+        &self.chunk_commitments
+    }
+}
+
+/// A verified commitment for a single chunk of plaintext.
+#[derive(Clone, Getters)]
+pub struct VerifiedChunkCommitment<I, F> {
+    /// Hash commitment to the plaintext.
+    #[getset(get = "pub")]
+    plaintext_hash: F,
+    /// Hash commitment to the arithmetic sum of the encodings of the plaintext.
+    #[getset(get = "pub")]
+    encoding_sum_hash: F,
+    /// The id of each bit of the plaintext.
+    #[getset(get = "pub")]
+    ids: I,
+}
+
+impl<I, F> VerifiedChunkCommitment<I, F>
+where
+    I: IdCollection,
+    F: Field,
+{
+    /// Creates a new `VerifiedChunkCommitment` instance.
+    ///
+    /// # Arguments
+    ///
+    /// * `plaintext_hash` - Hash commitment to the plaintext.
+    /// * `encoding_sum_hash` - Hash commitment to the arithmetic sum of the encodings of the
+    ///                         plaintext.
+    /// * `ids` - The id of each bit of the committed plaintext.
+    pub fn new(plaintext_hash: F, encoding_sum_hash: F, ids: I) -> Self {
+        Self {
+            plaintext_hash,
+            encoding_sum_hash,
+            ids,
+        }
+    }
+
+    /// Returns the bitlength of the plaintext committed to.
+    #[allow(clippy::len_without_is_empty)]
+    pub fn len(&self) -> usize {
+        self.ids.len()
+    }
+}
+
+impl<I, F> From<UnverifiedCommitment<I, F>> for VerifiedCommitment<I, F>
+where
+    I: IdCollection,
+    F: Field,
+{
+    fn from(unverified: UnverifiedCommitment<I, F>) -> Self {
+        Self {
+            chunk_commitments: unverified
+                .chunk_commitments
+                .into_iter()
+                .map(|com| VerifiedChunkCommitment {
+                    plaintext_hash: com.plaintext_hash,
+                    encoding_sum_hash: com.encoding_sum_hash,
+                    ids: com.ids,
+                })
+                .collect::<Vec<_>>(),
+        }
+    }
+}
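[Editor's note] These commitment types lean on `getset` to expose read-only accessors over private fields. A minimal sketch of the pattern, assuming getset's documented derive behavior (field names here are illustrative):

```rust
use getset::Getters;

#[derive(Getters)]
struct Commitment {
    /// Hash commitment to some data; private, but readable via the generated getter.
    #[getset(get = "pub")]
    hash: u64,
}

fn main() {
    let c = Commitment { hash: 42 };
    // The derive generates `fn hash(&self) -> &u64`.
    assert_eq!(*c.hash(), 42);
}
```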
diff --git a/crates/components/authdecode/authdecode-core/src/verifier/error.rs b/crates/components/authdecode/authdecode-core/src/verifier/error.rs
new file mode 100644
index 0000000000..cdfbbc77f7
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/src/verifier/error.rs
@@ -0,0 +1,26 @@
+use crate::{backend::traits::Field, id::IdCollection, msgs::MessageError};
+
+#[derive(Debug, thiserror::Error)]
+pub enum VerifierError {
+    #[error(transparent)]
+    IOError(#[from] std::io::Error),
+    #[error("The prover has provided the wrong number of proofs. Expected {0}. Got {1}.")]
+    WrongProofCount(usize, usize),
+    #[error("Proof verification failed with an error: {0}")]
+    VerificationFailed(String),
+    #[error(transparent)]
+    EncodingProviderError(#[from] crate::encodings::EncodingProviderError),
+}
+
+impl<I, F> From<MessageError<I, F>> for VerifierError
+where
+    I: IdCollection,
+    F: Field,
+{
+    fn from(err: MessageError<I, F>) -> Self {
+        VerifierError::from(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            err.to_string(),
+        ))
+    }
+}
diff --git a/crates/components/authdecode/authdecode-core/src/verifier/state.rs b/crates/components/authdecode/authdecode-core/src/verifier/state.rs
new file mode 100644
index 0000000000..d416b6d149
--- /dev/null
+++ b/crates/components/authdecode/authdecode-core/src/verifier/state.rs
@@ -0,0 +1,50 @@
+//! AuthDecode verifier states.
+
+use crate::{
+    backend::traits::Field,
+    id::IdCollection,
+    verifier::commitment::{UnverifiedCommitment, VerifiedCommitment},
+};
+
+/// The initial state.
+pub struct Initialized {}
+opaque_debug::implement!(Initialized);
+
+/// The state after the verifier received the prover's commitment.
+pub struct CommitmentReceived<I, F> {
+    /// Details pertaining to each commitment.
+    pub commitments: Vec<UnverifiedCommitment<I, F>>,
+}
+opaque_debug::implement!(CommitmentReceived<I, F>);
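[Editor's note] The state modules on both sides use the sealed-trait pattern: downstream code can name the state trait in bounds but cannot implement it, because the private `sealed::Sealed` supertrait is unreachable. A minimal, crate-independent sketch:

```rust
mod state {
    pub trait State: sealed::Sealed {}

    pub struct Initialized;
    pub struct Done;

    impl State for Initialized {}
    impl State for Done {}

    mod sealed {
        // Private supertrait: only this module can add implementors.
        pub trait Sealed {}
        impl Sealed for super::Initialized {}
        impl Sealed for super::Done {}
    }
}

fn only_valid_states<S: state::State>(_s: S) {}

fn main() {
    only_valid_states(state::Initialized);
    only_valid_states(state::Done);
}
```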
+
+/// The state after the commitments have been successfully verified.
+pub struct VerifiedSuccessfully<I, F> {
+    /// Commitments which have been successfully verified.
+    pub commitments: Vec<VerifiedCommitment<I, F>>,
+}
+opaque_debug::implement!(VerifiedSuccessfully<I, F>);
+
+#[allow(missing_docs)]
+pub trait VerifierState: sealed::Sealed {}
+
+impl VerifierState for Initialized {}
+impl<I, F> VerifierState for CommitmentReceived<I, F>
+where
+    I: IdCollection,
+    F: Field,
+{
+}
+impl<I, F> VerifierState for VerifiedSuccessfully<I, F> {}
+
+mod sealed {
+    use crate::{backend::traits::Field, id::IdCollection};
+    pub trait Sealed {}
+    impl Sealed for super::Initialized {}
+    impl<I, F> Sealed for super::CommitmentReceived<I, F>
+    where
+        I: IdCollection,
+        F: Field,
+    {
+    }
+    impl<I, F> Sealed for super::VerifiedSuccessfully<I, F> {}
+}
diff --git a/crates/components/authdecode/authdecode/Cargo.toml b/crates/components/authdecode/authdecode/Cargo.toml
new file mode 100644
index 0000000000..bfb32c860f
--- /dev/null
+++ b/crates/components/authdecode/authdecode/Cargo.toml
@@ -0,0 +1,40 @@
+[package]
+name = "tlsn-authdecode"
+authors = ["TLSNotary Team"]
+description = "A 2PC protocol for authenticated decoding of encodings in zk"
+keywords = ["tls", "mpc", "2pc"]
+categories = ["cryptography"]
+license = "MIT OR Apache-2.0"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "authdecode"
+
+[features]
+default = []
+tracing = ["dep:tracing"]
+
+[dependencies]
+tlsn-authdecode-core = { workspace = true }
+tlsn-utils-aio = { workspace = true }
+
+futures-util = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+tracing = { workspace = true, optional = true }
+
+[dev-dependencies]
+tlsn-authdecode-core = { workspace = true, features = ["fixtures", "mock"] }
+
+criterion = { workspace = true, features = ["async_tokio"] }
+rstest = { workspace = true }
+tokio = { workspace = true, features = [
+    "net",
+    "macros",
+    "rt",
+    "rt-multi-thread",
+] }
+
+[[bench]]
+name = "halo2"
+harness = false
\ No newline at end of file
diff --git a/crates/components/authdecode/authdecode/benches/halo2.rs b/crates/components/authdecode/authdecode/benches/halo2.rs
new file mode 100644
index 0000000000..b3e7350bfc
--- /dev/null
+++ b/crates/components/authdecode/authdecode/benches/halo2.rs
@@ -0,0 +1,55 @@
+//! Benches for running the AuthDecode protocol with the halo2 backend.
+
+use authdecode::{Prover, Verifier};
+use authdecode_core::fixtures::{self, commitment_data};
+use criterion::{criterion_group, criterion_main, Criterion};
+use futures_util::StreamExt;
+use utils_aio::duplex::MemoryDuplex;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let mut group = c.benchmark_group("authdecode");
+    group.sample_size(10);
+    let rt = tokio::runtime::Runtime::new().unwrap();
+
+    group.bench_function("authdecode_halo2", |b| {
+        b.to_async(&rt).iter(authdecode_halo2)
+    });
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
+
+async fn authdecode_halo2() {
+    let pair = authdecode_core::backend::halo2::fixtures::backend_pair_mock();
+    let commitment_data = commitment_data();
+    let encoding_provider = fixtures::encoding_provider();
+
+    let prover = Prover::new(Box::new(pair.0));
+    let verifier = Verifier::new(Box::new(pair.1));
+
+    let (prover_channel, verifier_channel) = MemoryDuplex::new();
+
+    let (mut prover_sink, _) = prover_channel.split();
+    let (_, mut verifier_stream) = verifier_channel.split();
+
+    let prover = prover
+        .commit(&mut prover_sink, commitment_data)
+        .await
+        .unwrap();
+
+    let verifier = verifier
+        .receive_commitments(&mut verifier_stream)
+        .await
+        .unwrap();
+
+    // An encoding provider is instantiated with authenticated full encodings from an external
+    // context.
+    let _ = prover
+        .prove(&mut prover_sink, &encoding_provider)
+        .await
+        .unwrap();
+
+    let _ = verifier
+        .verify(&mut verifier_stream, &encoding_provider)
+        .await
+        .unwrap();
+}
diff --git a/crates/components/authdecode/authdecode/src/lib.rs b/crates/components/authdecode/authdecode/src/lib.rs
new file mode 100644
index 0000000000..9b6f527527
--- /dev/null
+++ b/crates/components/authdecode/authdecode/src/lib.rs
@@ -0,0 +1,140 @@
+//! Implementation of the AuthDecode protocol.
+//!
+//! The protocol performs authenticated decoding of encodings in zero knowledge.
+//!
+//! One of the use cases of AuthDecode is for the garbled circuits (GC) evaluator to produce a
+//! zk-friendly hash commitment to either the GC input or the GC output, where computing such a
+//! commitment directly using GC would be prohibitively expensive.
+//!
+//! The protocol consists of the following steps:
+//! 1. The Prover commits to both the plaintext and the arithmetic sum of the active encodings of
+//!    the bits of the plaintext. (The protocol assumes that the Prover ascertained beforehand
+//!    that the active encodings are authentic.)
+//! 2. The Prover obtains the full encodings of the plaintext bits from some outer context and
+//!    uses them to create a zk proof, proving that during Step 1 they knew the correct active
+//!    encodings of the plaintext and also proving that a hash commitment H is an authentic
+//!    commitment to the plaintext.
+//! 3. The Verifier verifies the proof and accepts H as an authentic hash commitment to the
+//!    plaintext.
+//!
+//! Important: when using the protocol, you must ensure that the Prover obtains the full encodings
+//! from an outer context only **after** they've made a commitment in Step 1.
+
+mod prover;
+mod verifier;
+
+pub use prover::Prover;
+pub use verifier::Verifier;
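[Editor's note] A toy demonstration (illustrative numbers, not this crate's API) of why the commit-before-reveal ordering above matters: once both encodings of every bit are known, a cheating prover can compute a consistent encoding sum for *any* claimed plaintext, so a sum commitment made after the reveal binds nothing.

```rust
fn main() {
    // Full encoding pairs (e0, e1), one pair per plaintext bit.
    let full: [(u128, u128); 3] = [(10, 17), (4, 5), (8, 30)];

    // The honest plaintext is 1, 0, 1; a cheater wants to claim 0, 1, 1.
    let honest = [true, false, true];
    let forged = [false, true, true];

    let sum = |bits: [bool; 3]| -> u128 {
        full.iter()
            .zip(bits)
            .map(|((e0, e1), b)| if b { *e1 } else { *e0 })
            .sum()
    };

    // With full knowledge of the pairs, the sum for the forged plaintext is
    // just as computable as the honest one, so committing only *before* the
    // full encodings are revealed makes the commitment binding.
    assert_ne!(sum(honest), sum(forged));
}
```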
+
+#[cfg(test)]
+mod tests {
+    use crate::*;
+
+    use authdecode_core::{
+        backend::traits::{Field, ProverBackend, VerifierBackend},
+        fixtures,
+        mock::{MockBitIds, MockEncodingProvider},
+        prover::{CommitmentData, ProofGenerated},
+        verifier::VerifiedSuccessfully,
+    };
+    use futures_util::StreamExt;
+    use rstest::*;
+    use serde::{de::DeserializeOwned, Serialize};
+    use std::ops::{Add, Sub};
+    use utils_aio::duplex::MemoryDuplex;
+
+    #[fixture]
+    fn commitment_data() -> Vec<CommitmentData<MockBitIds>> {
+        fixtures::commitment_data()
+    }
+
+    #[fixture]
+    fn encoding_provider() -> MockEncodingProvider<MockBitIds> {
+        fixtures::encoding_provider()
+    }
+
+    // Tests the protocol with a mock backend.
+    #[rstest]
+    #[tokio::test]
+    async fn test_mock_backend(
+        commitment_data: Vec<CommitmentData<MockBitIds>>,
+        encoding_provider: MockEncodingProvider<MockBitIds>,
+    ) {
+        run_authdecode(
+            authdecode_core::backend::mock::backend_pair(),
+            commitment_data,
+            encoding_provider,
+        )
+        .await;
+    }
+
+    // Tests the protocol with a halo2 backend.
+    #[rstest]
+    #[tokio::test]
+    async fn test_halo2_backend(
+        commitment_data: Vec<CommitmentData<MockBitIds>>,
+        encoding_provider: MockEncodingProvider<MockBitIds>,
+    ) {
+        run_authdecode(
+            authdecode_core::backend::halo2::fixtures::backend_pair_mock(),
+            commitment_data,
+            encoding_provider,
+        )
+        .await;
+    }
+
+    // Runs the protocol with the given backends.
+    // Returns the prover and the verifier in their finalized state.
+    #[allow(clippy::type_complexity)]
+    async fn run_authdecode<F>(
+        pair: (
+            impl ProverBackend<F> + 'static,
+            impl VerifierBackend<F> + 'static,
+        ),
+        commitment_data: Vec<CommitmentData<MockBitIds>>,
+        encoding_provider: MockEncodingProvider<MockBitIds>,
+    ) -> (
+        Prover<MockBitIds, ProofGenerated<MockBitIds, F>, F>,
+        Verifier<MockBitIds, VerifiedSuccessfully<MockBitIds, F>, F>,
+    )
+    where
+        F: Field
+            + Add<Output = F>
+            + Sub<Output = F>
+            + Serialize
+            + DeserializeOwned
+            + Clone
+            + Send
+            + 'static,
+    {
+        let prover = Prover::new(Box::new(pair.0));
+        let verifier = Verifier::new(Box::new(pair.1));
+
+        let (prover_channel, verifier_channel) = MemoryDuplex::new();
+
+        let (mut prover_sink, _) = prover_channel.split();
+        let (_, mut verifier_stream) = verifier_channel.split();
+
+        let prover = prover
+            .commit(&mut prover_sink, commitment_data)
+            .await
+            .unwrap();
+
+        let verifier = verifier
+            .receive_commitments(&mut verifier_stream)
+            .await
+            .unwrap();
+
+        // An encoding provider is instantiated with authenticated full encodings from an external
+        // context.
+        let prover = prover
+            .prove(&mut prover_sink, &encoding_provider)
+            .await
+            .unwrap();
+
+        let verifier = verifier
+            .verify(&mut verifier_stream, &encoding_provider)
+            .await
+            .unwrap();
+
+        (prover, verifier)
+    }
+}
diff --git a/crates/components/authdecode/authdecode/src/prover.rs b/crates/components/authdecode/authdecode/src/prover.rs
new file mode 100644
index 0000000000..1a6cf3bb67
--- /dev/null
+++ b/crates/components/authdecode/authdecode/src/prover.rs
@@ -0,0 +1,125 @@
+use futures_util::SinkExt;
+use std::ops::Add;
+use utils_aio::sink::IoSink;
+
+use authdecode_core::{
+    backend::traits::{Field, ProverBackend as Backend},
+    encodings::EncodingProvider,
+    id::IdCollection,
+    msgs::Message,
+    prover::{CommitmentData, Committed, Initialized, ProofGenerated, ProverError, ProverState},
+    Prover as CoreProver,
+};
+
+#[cfg(feature = "tracing")]
+use tracing::{debug, debug_span, instrument, Instrument};
+
+/// Prover in the AuthDecode protocol.
+pub struct Prover<I, S, F>
+where
+    I: IdCollection,
+    F: Field + Add<Output = F>,
+    S: ProverState,
+{
+    /// The wrapped prover in the AuthDecode protocol.
+    prover: CoreProver<I, S, F>,
+}
+
+impl<I, F> Prover<I, Initialized, F>
+where
+    I: IdCollection,
+    F: Field + Add<Output = F>,
+{
+    /// Creates a new prover.
+    ///
+    /// # Arguments
+    ///
+    /// * `backend` - The zk backend.
+    pub fn new(backend: Box<dyn Backend<F>>) -> Self {
+        Self {
+            prover: CoreProver::new(backend),
+        }
+    }
+}
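[Editor's note] This crate is a thin async layer over the sans-IO core: each wrapper method drives the core's pure state transition, then sends the produced message into a sink. A minimal, library-style sketch of the shape (names are illustrative; only the `futures_util` Sink API is assumed):

```rust
use futures_util::{Sink, SinkExt};

struct CoreStep {
    msg: Vec<u8>,
}

impl CoreStep {
    // The core performs pure computation and returns a message to transmit.
    fn run(self) -> (Self, Vec<u8>) {
        let msg = self.msg.clone();
        (self, msg)
    }
}

// The wrapper owns the IO: run the core, then flush its output to the sink.
async fn step<Si>(core: CoreStep, sink: &mut Si) -> Result<CoreStep, Si::Error>
where
    Si: Sink<Vec<u8>> + Unpin,
{
    let (core, msg) = core.run();
    sink.send(msg).await?;
    Ok(core)
}
```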
+
+impl<I, F> Prover<I, Initialized, F>
+where
+    I: IdCollection,
+    F: Field + Add<Output = F>,
+{
+    /// Creates a commitment to each element in the `data_set`.
+    ///
+    /// # Arguments
+    ///
+    /// * `sink` - The sink for sending messages to the verifier.
+    /// * `data_set` - The set of commitment data to be committed to.
+    #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))]
+    pub async fn commit<Si: IoSink<Message<I, F>> + Send + Unpin>(
+        self,
+        sink: &mut Si,
+        data_set: Vec<CommitmentData<I>>,
+    ) -> Result<Prover<I, Committed<I, F>, F>, ProverError>
+    where
+        I: IdCollection,
+        F: Field + Clone + std::ops::Add<Output = F>,
+    {
+        let (core_prover, msg) = self.prover.commit(data_set)?;
+
+        sink.send(Message::Commit(msg)).await?;
+
+        Ok(Prover {
+            prover: core_prover,
+        })
+    }
+
+    /// Creates a commitment to each element in the `data_set` with the provided salts.
+    ///
+    /// # Arguments
+    ///
+    /// * `sink` - The sink for sending messages to the verifier.
+    /// * `data_set` - The set of commitment data with salts for each chunk of it.
+    #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))]
+    pub async fn commit_with_salt<Si: IoSink<Message<I, F>> + Send + Unpin>(
+        self,
+        sink: &mut Si,
+        data_set: Vec<(CommitmentData<I>, Vec<F>)>,
+    ) -> Result<Prover<I, Committed<I, F>, F>, ProverError>
+    where
+        I: IdCollection,
+        F: Field + Clone + std::ops::Add<Output = F>,
+    {
+        let (core_prover, msg) = self.prover.commit_with_salt(data_set)?;
+
+        sink.send(Message::Commit(msg)).await?;
+
+        Ok(Prover {
+            prover: core_prover,
+        })
+    }
+}
+
+impl<I, F> Prover<I, Committed<I, F>, F>
+where
+    I: IdCollection,
+    F: Field + Clone + std::ops::Sub<Output = F> + std::ops::Add<Output = F>,
+{
+    /// Generates zk proofs.
+    ///
+    /// # Arguments
+    ///
+    /// * `sink` - The sink for sending messages to the verifier.
+    /// * `encoding_provider` - The provider of full encodings for the plaintext committed to
+    ///                         earlier.
+    #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))]
+    pub async fn prove<Si: IoSink<Message<I, F>> + Send + Unpin>(
+        self,
+        sink: &mut Si,
+        encoding_provider: &impl EncodingProvider<I>,
+    ) -> Result<Prover<I, ProofGenerated<I, F>, F>, ProverError> {
+        let (prover, msg) = self.prover.prove(encoding_provider)?;
+
+        sink.send(Message::Proofs(msg)).await?;
+
+        Ok(Prover { prover })
+    }
+}
diff --git a/crates/components/authdecode/authdecode/src/verifier.rs b/crates/components/authdecode/authdecode/src/verifier.rs
new file mode 100644
index 0000000000..fa2520417d
--- /dev/null
+++ b/crates/components/authdecode/authdecode/src/verifier.rs
@@ -0,0 +1,93 @@
+use utils_aio::stream::{ExpectStreamExt, IoStream};
+
+use authdecode_core::{
+    backend::traits::{Field, VerifierBackend as Backend},
+    encodings::EncodingProvider,
+    id::IdCollection,
+    msgs::Message,
+    verifier::{
+        CommitmentReceived, Initialized, VerifiedSuccessfully, VerifierError, VerifierState,
+    },
+    Verifier as CoreVerifier,
+};
+
+#[cfg(feature = "tracing")]
+use tracing::{debug, debug_span, instrument, Instrument};
+
+/// Verifier in the AuthDecode protocol.
+pub struct Verifier<I, S, F>
+where
+    I: IdCollection,
+    F: Field,
+    S: VerifierState,
+{
+    /// The wrapped verifier in the AuthDecode protocol.
+    verifier: CoreVerifier<I, S, F>,
+}
+
+impl<I, F> Verifier<I, Initialized, F>
+where
+    I: IdCollection,
+    F: Field,
+{
+    /// Creates a new verifier.
+    ///
+    /// # Arguments
+    ///
+    /// * `backend` - The zk backend.
+    pub fn new(backend: Box<dyn Backend<F>>) -> Self {
+        Self {
+            verifier: CoreVerifier::new(backend),
+        }
+    }
+
+    /// Receives the commitments and stores them.
+    ///
+    /// # Arguments
+    ///
+    /// * `stream` - The stream for receiving messages from the prover.
+    #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))]
+    pub async fn receive_commitments<St: IoStream<Message<I, F>> + Send + Unpin>(
+        self,
+        stream: &mut St,
+    ) -> Result<Verifier<I, CommitmentReceived<I, F>, F>, VerifierError> {
+        let commitments = stream
+            .expect_next()
+            .await?
+            .try_into_commit()
+            .map_err(VerifierError::from)?;
+
+        Ok(Verifier {
+            verifier: self.verifier.receive_commitments(commitments)?,
+        })
+    }
+}
+
+impl<I, F> Verifier<I, CommitmentReceived<I, F>, F>
+where
+    I: IdCollection,
+    F: Field + std::ops::Add<Output = F> + std::ops::Sub<Output = F> + Clone,
+{
+    /// Verifies proofs for the commitments received earlier.
+    ///
+    /// # Arguments
+    ///
+    /// * `stream` - The stream for receiving messages from the prover.
+    /// * `encoding_provider` - The provider of full encodings for the plaintext being committed
+    ///                         to.
+    #[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))]
+    pub async fn verify<St: IoStream<Message<I, F>> + Send + Unpin>(
+        self,
+        stream: &mut St,
+        encoding_provider: &(impl EncodingProvider<I> + 'static),
+    ) -> Result<Verifier<I, VerifiedSuccessfully<I, F>, F>, VerifierError> {
+        let proofs = stream
+            .expect_next()
+            .await?
+            .try_into_proofs()
+            .map_err(VerifierError::from)?;
+
+        Ok(Verifier {
+            verifier: self.verifier.verify(proofs, encoding_provider)?,
+        })
+    }
+}