From 7b237476954b9e919cc3173e68dfe94e43556695 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 22 Aug 2023 18:16:11 -0400 Subject: [PATCH] [feat] Add Poseidon Chip (#114) * Add Poseidon hasher * Fix test/lint * Fix nits * Fix lint * Fix nits & add comments * Add prover test * Fix CI --- halo2-base/Cargo.toml | 2 +- halo2-base/src/gates/flex_gate.rs | 16 ++ halo2-base/src/gates/tests/flex_gate.rs | 12 + halo2-base/src/poseidon/hasher/mod.rs | 206 +++++++++++++----- halo2-base/src/poseidon/hasher/state.rs | 143 ++++++++++-- .../poseidon/hasher/tests/compatibility.rs | 22 +- .../src/poseidon/hasher/tests/hasher.rs | 129 +++++++++++ halo2-base/src/poseidon/hasher/tests/mod.rs | 68 +----- halo2-base/src/poseidon/hasher/tests/state.rs | 129 +++++++++++ halo2-base/src/poseidon/mod.rs | 112 ++++++++++ 10 files changed, 699 insertions(+), 140 deletions(-) create mode 100644 halo2-base/src/poseidon/hasher/tests/hasher.rs create mode 100644 halo2-base/src/poseidon/hasher/tests/state.rs diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index cfa1b3ae..68fa66f5 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -73,4 +73,4 @@ harness = false [[example]] name = "inner_product" -features = ["test-utils"] \ No newline at end of file +required-features = ["test-utils"] diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index b89126c2..b456361c 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -180,6 +180,14 @@ pub trait GateInstructions { ctx.assign_region_last([a, b, Constant(F::ONE), Witness(out_val)], [0]) } + /// Constrains and returns `out = a + 1`. + /// + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + fn inc(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.add(ctx, a, Constant(F::ONE)) + } + /// Constrains and returns `a + b * (-1) = out`. /// /// Defines a vertical gate of form | a - b | b | 1 | a |, where (a - b) = out. @@ -200,6 +208,14 @@ pub trait GateInstructions { ctx.get(-4) } + /// Constrains and returns `out = a - 1`. + /// + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + fn dec(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.sub(ctx, a, Constant(F::ONE)) + } + /// Constrains and returns `a - b * c = out`. /// /// Defines a vertical gate of form | a - b * c | b | c | a |, where (a - b * c) = out. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index 625e3ff6..ba079c70 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -14,12 +14,24 @@ pub fn test_add(inputs: &[QuantumCell]) -> Fr { base_test().run_gate(|ctx, chip| *chip.add(ctx, inputs[0], inputs[1]).value()) } +#[test_case(Witness(Fr::from(10))=> Fr::from(11); "inc(): 10 -> 11")] +#[test_case(Witness(Fr::from(1))=> Fr::from(2); "inc(): 1 -> 2")] +pub fn test_inc(input: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.inc(ctx, input).value()) +} + #[test_case(&[10, 12].map(Fr::from).map(Witness)=> -Fr::from(2) ; "sub(): 10 - 12 == -2")] #[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(0) ; "sub(): 1 - 1 == 0")] pub fn test_sub(inputs: &[QuantumCell]) -> Fr { base_test().run_gate(|ctx, chip| *chip.sub(ctx, inputs[0], inputs[1]).value()) } +#[test_case(Witness(Fr::from(10))=> Fr::from(9); "dec(): 10 -> 9")] +#[test_case(Witness(Fr::from(1))=> Fr::from(0); "dec(): 1 -> 0")] +pub fn test_dec(input: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.dec(ctx, input).value()) +} + #[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub_mul(): 1 - 1 * 1 == 0")] pub fn test_sub_mul(inputs: &[QuantumCell]) -> Fr { base_test().run_gate(|ctx, chip| *chip.sub_mul(ctx, inputs[0], inputs[1], inputs[2]).value()) diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index d7843b1b..f97a3216 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -1,11 +1,17 @@ -use std::mem; - use crate::{ gates::GateInstructions, poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState}, - AssignedValue, Context, ScalarField, + safe_types::{RangeInstructions, SafeTypeChip}, + utils::BigPrimeField, + AssignedValue, Context, + QuantumCell::Constant, + ScalarField, }; +use getset::Getters; +use num_bigint::BigUint; +use std::{cell::OnceCell, mem}; + #[cfg(test)] mod tests; @@ -16,15 +22,142 @@ pub mod spec; /// Module for poseidon states. pub mod state; -/// Poseidon hasher. This is stateful. +/// Stateless Poseidon hasher. pub struct PoseidonHasher { + spec: OptimizedPoseidonSpec, + consts: OnceCell>, +} +#[derive(Getters)] +struct PoseidonHasherConsts { + #[getset(get = "pub")] + init_state: PoseidonState, + // hash of an empty input(""). + #[getset(get = "pub")] + empty_hash: AssignedValue, +} + +impl PoseidonHasherConsts { + pub fn new( + ctx: &mut Context, + gate: &impl GateInstructions, + spec: &OptimizedPoseidonSpec, + ) -> Self { + let init_state = PoseidonState::default(ctx); + let mut state = init_state.clone(); + let empty_hash = fix_len_array_squeeze(ctx, gate, &[], &mut state, spec); + Self { init_state, empty_hash } + } +} + +impl PoseidonHasher { + /// Create a poseidon hasher from an existing spec. + pub fn new(spec: OptimizedPoseidonSpec) -> Self { + Self { spec, consts: OnceCell::new() } + } + /// Initialize necessary consts of hasher. Must be called before any computation. + pub fn initialize_consts(&mut self, ctx: &mut Context, gate: &impl GateInstructions) { + self.consts.get_or_init(|| PoseidonHasherConsts::::new(ctx, gate, &self.spec)); + } + + fn empty_hash(&self) -> &AssignedValue { + self.consts.get().unwrap().empty_hash() + } + fn init_state(&self) -> &PoseidonState { + self.consts.get().unwrap().init_state() + } + + /// Constrains and returns hash of a witness array with a variable length. + /// + /// Assumes `len` is within [usize] and `len <= inputs.len()`. + /// * inputs: An right-padded array of [AssignedValue]. Constraints on paddings are not required. + /// * len: Length of `inputs`. + /// Return hash of `inputs`. + pub fn hash_var_len_array( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + inputs: &[AssignedValue], + len: AssignedValue, + ) -> AssignedValue + where + F: BigPrimeField, + { + let max_len = inputs.len(); + if max_len == 0 { + return *self.empty_hash(); + }; + + // len <= max_len --> num_of_bits(len) <= num_of_bits(max_len) + let num_bits = (usize::BITS - max_len.leading_zeros()) as usize; + // num_perm = len // RATE + 1, len_last_chunk = len % RATE + let (mut num_perm, len_last_chunk) = range.div_mod(ctx, len, BigUint::from(RATE), num_bits); + num_perm = range.gate().inc(ctx, num_perm); + + let mut state = self.init_state().clone(); + let mut result_state = state.clone(); + for (i, chunk) in inputs.chunks(RATE).enumerate() { + let is_last_perm = + range.gate().is_equal(ctx, num_perm, Constant(F::from((i + 1) as u64))); + let len_chunk = range.gate().select( + ctx, + len_last_chunk, + Constant(F::from(RATE as u64)), + is_last_perm, + ); + + state.permutation(ctx, range.gate(), chunk, Some(len_chunk), &self.spec); + result_state.select( + ctx, + range.gate(), + SafeTypeChip::::unsafe_to_bool(is_last_perm), + &state, + ); + } + if max_len % RATE == 0 { + let is_last_perm = range.gate().is_equal( + ctx, + num_perm, + Constant(F::from((max_len / RATE + 1) as u64)), + ); + let len_chunk = ctx.load_zero(); + state.permutation(ctx, range.gate(), &[], Some(len_chunk), &self.spec); + result_state.select( + ctx, + range.gate(), + SafeTypeChip::::unsafe_to_bool(is_last_perm), + &state, + ); + } + result_state.s[1] + } + + /// Constrains and returns hash of a witness array. + /// + /// * inputs: An array of [AssignedValue]. + /// Return hash of `inputs`. + pub fn hash_fix_len_array( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + inputs: &[AssignedValue], + ) -> AssignedValue + where + F: BigPrimeField, + { + let mut state = self.init_state().clone(); + fix_len_array_squeeze(ctx, range.gate(), inputs, &mut state, &self.spec) + } +} + +/// Poseidon sponge. This is stateful. +pub struct PoseidonSponge { init_state: PoseidonState, state: PoseidonState, spec: OptimizedPoseidonSpec, absorbing: Vec>, } -impl PoseidonHasher { +impl PoseidonSponge { /// Create new Poseidon hasher. pub fn new( ctx: &mut Context, @@ -64,53 +197,26 @@ impl PoseidonHasher, ) -> AssignedValue { let input_elements = mem::take(&mut self.absorbing); - let exact = input_elements.len() % RATE == 0; - - for chunk in input_elements.chunks(RATE) { - self.permutation(ctx, gate, chunk.to_vec()); - } - if exact { - self.permutation(ctx, gate, vec![]); - } - - self.state.s[1] + fix_len_array_squeeze(ctx, gate, &input_elements, &mut self.state, &self.spec) } +} - fn permutation( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - inputs: Vec>, - ) { - let r_f = self.spec.r_f / 2; - let mds = &self.spec.mds_matrices.mds.0; - let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0; - let sparse_matrices = &self.spec.mds_matrices.sparse_matrices; - - // First half of the full round - let constants = &self.spec.constants.start; - self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); - for constants in constants.iter().skip(1).take(r_f - 1) { - self.state.sbox_full(ctx, gate, constants); - self.state.apply_mds(ctx, gate, mds); - } - self.state.sbox_full(ctx, gate, constants.last().unwrap()); - self.state.apply_mds(ctx, gate, pre_sparse_mds); - - // Partial rounds - let constants = &self.spec.constants.partial; - for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { - self.state.sbox_part(ctx, gate, constant); - self.state.apply_sparse_mds(ctx, gate, sparse_mds); - } +/// ATTETION: input_elements.len() needs to be fixed at compile time. +fn fix_len_array_squeeze( + ctx: &mut Context, + gate: &impl GateInstructions, + input_elements: &[AssignedValue], + state: &mut PoseidonState, + spec: &OptimizedPoseidonSpec, +) -> AssignedValue { + let exact = input_elements.len() % RATE == 0; - // Second half of the full rounds - let constants = &self.spec.constants.end; - for constants in constants.iter() { - self.state.sbox_full(ctx, gate, constants); - self.state.apply_mds(ctx, gate, mds); - } - self.state.sbox_full(ctx, gate, &[F::ZERO; T]); - self.state.apply_mds(ctx, gate, mds); + for chunk in input_elements.chunks(RATE) { + state.permutation(ctx, gate, chunk, None, spec); } + if exact { + state.permutation(ctx, gate, &[], None, spec); + } + + state.s[1] } diff --git a/halo2-base/src/poseidon/hasher/state.rs b/halo2-base/src/poseidon/hasher/state.rs index 97883cc8..99cb6f21 100644 --- a/halo2-base/src/poseidon/hasher/state.rs +++ b/halo2-base/src/poseidon/hasher/state.rs @@ -1,8 +1,11 @@ use std::iter; +use itertools::Itertools; + use crate::{ gates::GateInstructions, - poseidon::hasher::mds::SparseMDSMatrix, + poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec}, + safe_types::SafeBool, utils::ScalarField, AssignedValue, Context, QuantumCell::{Constant, Existing}, @@ -23,7 +26,75 @@ impl PoseidonState, + gate: &impl GateInstructions, + inputs: &[AssignedValue], + len: Option>, + spec: &OptimizedPoseidonSpec, + ) { + let r_f = spec.r_f / 2; + let mds = &spec.mds_matrices.mds.0; + let pre_sparse_mds = &spec.mds_matrices.pre_sparse_mds.0; + let sparse_matrices = &spec.mds_matrices.sparse_matrices; + + // First half of the full round + let constants = &spec.constants.start; + if let Some(len) = len { + // Note: this doesn't mean `padded_inputs` is 0 padded because there is no constraints on `inputs[len..]` + let padded_inputs: [AssignedValue; RATE] = + core::array::from_fn( + |i| if i < inputs.len() { inputs[i] } else { ctx.load_zero() }, + ); + self.absorb_var_len_with_pre_constants(ctx, gate, padded_inputs, len, &constants[0]); + } else { + self.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); + } + for constants in constants.iter().skip(1).take(r_f - 1) { + self.sbox_full(ctx, gate, constants); + self.apply_mds(ctx, gate, mds); + } + self.sbox_full(ctx, gate, constants.last().unwrap()); + self.apply_mds(ctx, gate, pre_sparse_mds); + + // Partial rounds + let constants = &spec.constants.partial; + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.sbox_part(ctx, gate, constant); + self.apply_sparse_mds(ctx, gate, sparse_mds); + } + + // Second half of the full rounds + let constants = &spec.constants.end; + for constants in constants.iter() { + self.sbox_full(ctx, gate, constants); + self.apply_mds(ctx, gate, mds); + } + self.sbox_full(ctx, gate, &[F::ZERO; T]); + self.apply_mds(ctx, gate, mds); + } + + /// Constrains and set self to a specific state if `selector` is true. + pub fn select( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + selector: SafeBool, + set_to: &Self, + ) { + for i in 0..T { + self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref()); + } + } + + fn x_power5_with_constant( ctx: &mut Context, gate: &impl GateInstructions, x: AssignedValue, @@ -34,7 +105,7 @@ impl PoseidonState, gate: &impl GateInstructions, @@ -45,21 +116,16 @@ impl PoseidonState, - gate: &impl GateInstructions, - constant: &F, - ) { + fn sbox_part(&mut self, ctx: &mut Context, gate: &impl GateInstructions, constant: &F) { let x = &mut self.s[0]; *x = Self::x_power5_with_constant(ctx, gate, *x, constant); } - pub fn absorb_with_pre_constants( + fn absorb_with_pre_constants( &mut self, ctx: &mut Context, gate: &impl GateInstructions, - inputs: Vec>, + inputs: &[AssignedValue], pre_constants: &[F; T], ) { assert!(inputs.len() < T); @@ -94,7 +160,58 @@ impl PoseidonState, + gate: &impl GateInstructions, + inputs: [AssignedValue; RATE], + len: AssignedValue, + pre_constants: &[F; T], + ) { + // Explanation of what's going on: before each round of the poseidon permutation, + // two things have to be added to the state: inputs (the absorbed elements) and + // preconstants. Imagine the state as a list of T elements, the first of which is + // the capacity: |--cap--|--el1--|--el2--|--elR--| + // - A preconstant is added to each of all T elements (which is different for each) + // - The inputs are added to all elements starting from el1 (so, not to the capacity), + // to as many elements as inputs are available. + // - To the first element for which no input is left (if any), an extra 1 is added. + + // Adding preconstants to the current state. + for (i, pre_const) in pre_constants.iter().enumerate() { + self.s[i] = gate.add(ctx, self.s[i], Constant(*pre_const)); + } + + // Generate a mask array where a[i] = i < len for i = 0..RATE. + let idx = gate.dec(ctx, len); + let len_indicator = gate.idx_to_indicator(ctx, idx, RATE); + // inputs_mask[i] = sum(len_indicator[i..]) + let mut inputs_mask = + gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec(); + inputs_mask.reverse(); + + let padded_inputs = inputs + .iter() + .zip(inputs_mask.iter()) + .map(|(input, mask)| gate.mul(ctx, *input, *mask)) + .collect_vec(); + for i in 0..RATE { + // Add all inputs. + self.s[i + 1] = gate.add(ctx, self.s[i + 1], padded_inputs[i]); + // Add the extra 1 after inputs. + if i + 2 < T { + self.s[i + 2] = gate.add(ctx, self.s[i + 2], len_indicator[i]); + } + } + // If len == 0, inputs_mask is all 0. Then the extra 1 should be added into s[1]. + let empty_extra_one = gate.not(ctx, inputs_mask[0]); + self.s[1] = gate.add(ctx, self.s[1], empty_extra_one); + } + + fn apply_mds( &mut self, ctx: &mut Context, gate: &impl GateInstructions, @@ -110,7 +227,7 @@ impl PoseidonState, gate: &impl GateInstructions, diff --git a/halo2-base/src/poseidon/hasher/tests/compatibility.rs b/halo2-base/src/poseidon/hasher/tests/compatibility.rs index b8a48003..1b850c91 100644 --- a/halo2-base/src/poseidon/hasher/tests/compatibility.rs +++ b/halo2-base/src/poseidon/hasher/tests/compatibility.rs @@ -3,7 +3,7 @@ use std::{cmp::max, iter::zip}; use crate::{ gates::{builder::GateThreadBuilder, GateChip}, halo2_proofs::halo2curves::bn256::Fr, - poseidon::hasher::PoseidonHasher, + poseidon::hasher::PoseidonSponge, utils::ScalarField, }; use pse_poseidon::Poseidon; @@ -11,7 +11,7 @@ use rand::Rng; // make interleaved calls to absorb and squeeze elements and // check that the result is the same in-circuit and natively -fn poseidon_compatiblity_verification< +fn sponge_compatiblity_verification< F: ScalarField, const T: usize, const RATE: usize, @@ -31,7 +31,7 @@ fn poseidon_compatiblity_verification< // constructing native and in-circuit Poseidon sponges let mut native_sponge = Poseidon::::new(R_F, R_P); // assuming SECURE_MDS = 0 - let mut circuit_sponge = PoseidonHasher::::new::(ctx); + let mut circuit_sponge = PoseidonSponge::::new::(ctx); // preparing to interleave absorptions and squeezings let n_iterations = max(absorptions.len(), squeezings.len()); @@ -85,33 +85,33 @@ fn random_list_usize(len: usize, max: usize) -> Vec { } #[test] -fn test_poseidon_compatibility_squeezing_only() { +fn test_sponge_compatibility_squeezing_only() { let absorptions = Vec::new(); let squeezings = random_list_usize(10, 7); - poseidon_compatiblity_verification::(absorptions, squeezings); + sponge_compatiblity_verification::(absorptions, squeezings); } #[test] -fn test_poseidon_compatibility_absorbing_only() { +fn test_sponge_compatibility_absorbing_only() { let absorptions = random_nested_list_f(8, 5); let squeezings = Vec::new(); - poseidon_compatiblity_verification::(absorptions, squeezings); + sponge_compatiblity_verification::(absorptions, squeezings); } #[test] -fn test_poseidon_compatibility_interleaved() { +fn test_sponge_compatibility_interleaved() { let absorptions = random_nested_list_f(10, 5); let squeezings = random_list_usize(7, 10); - poseidon_compatiblity_verification::(absorptions, squeezings); + sponge_compatiblity_verification::(absorptions, squeezings); } #[test] -fn test_poseidon_compatibility_other_params() { +fn test_sponge_compatibility_other_params() { let absorptions = random_nested_list_f(10, 10); let squeezings = random_list_usize(10, 10); - poseidon_compatiblity_verification::(absorptions, squeezings); + sponge_compatiblity_verification::(absorptions, squeezings); } diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs new file mode 100644 index 00000000..1af52068 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -0,0 +1,129 @@ +use crate::{ + gates::{builder::GateThreadBuilder, range::RangeInstructions, RangeChip}, + halo2_proofs::halo2curves::bn256::Fr, + poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, + utils::{testing::base_test, BigPrimeField, ScalarField}, +}; +use pse_poseidon::Poseidon; +use rand::Rng; + +#[derive(Clone)] +struct Payload { + // Represent value of a right-padded witness array with a variable length + pub values: Vec, + // Length of `values`. + pub len: usize, +} + +// check if the results from hasher and native sponge are same. +fn hasher_compatiblity_verification< + F: ScalarField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec>, +) where + F: BigPrimeField, +{ + let lookup_bits = 3; + let mut builder = GateThreadBuilder::prover(); + let range = RangeChip::::default(lookup_bits); + + let ctx = builder.main(0); + + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + for payload in payloads { + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + native_sponge.update(&payload.values[..payload.len]); + let native_result = native_sponge.squeeze(); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(F::from(payload.len as u64)); + let hasher_result = hasher.hash_var_len_array(ctx, &range, &inputs, len); + // 0x1f0db93536afb96e038f897b4fb5548b6aa3144c46893a6459c4b847951a23b4 + assert_eq!(native_result, *hasher_result.value()); + } +} + +fn random_payload(max_len: usize, len: usize, max_value: usize) -> Payload { + assert!(len <= max_len); + let mut rng = rand::thread_rng(); + let mut values = Vec::new(); + for _ in 0..max_len { + values.push(F::from(rng.gen_range(0..=max_value) as u64)); + } + Payload { values, len } +} + +fn random_payload_without_len(max_len: usize, max_value: usize) -> Payload { + let mut rng = rand::thread_rng(); + let mut values = Vec::new(); + for _ in 0..max_len { + values.push(F::from(rng.gen_range(0..=max_value) as u64)); + } + Payload { values, len: rng.gen_range(0..=max_len) } +} + +#[test] +fn test_poseidon_hasher_compatiblity() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + // max_len = 0 + random_payload(0, 0, usize::MAX), + // max_len % RATE == 0 && len = 0 + random_payload(RATE * 2, 0, usize::MAX), + // max_len % RATE == 0 && 0 < len < max_len && len % RATE == 0 + random_payload(RATE * 2, RATE, usize::MAX), + // max_len % RATE == 0 && 0 < len < max_len && len % RATE != 0 + random_payload(RATE * 5, RATE * 2 + 1, usize::MAX), + // max_len % RATE == 0 && len == max_len + random_payload(RATE * 2, RATE * 2, usize::MAX), + random_payload(RATE * 5, RATE * 5, usize::MAX), + // len % RATE != 0 && len = 0 + random_payload(RATE * 2 + 1, 0, usize::MAX), + random_payload(RATE * 5 + 1, 0, usize::MAX), + // len % RATE != 0 && 0 < len < max_len && len % RATE == 0 + random_payload(RATE * 2 + 1, RATE, usize::MAX), + // len % RATE != 0 && 0 < len < max_len && len % RATE != 0 + random_payload(RATE * 5 + 1, RATE * 2 + 1, usize::MAX), + // len % RATE != 0 && len = max_len + random_payload(RATE * 2 + 1, RATE * 2 + 1, usize::MAX), + random_payload(RATE * 5 + 1, RATE * 5 + 1, usize::MAX), + ]; + hasher_compatiblity_verification::(payloads); + } +} + +#[test] +fn test_poseidon_hasher_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + const R_F: usize = 8; + const R_P: usize = 57; + + let max_lens = vec![0, RATE * 2, RATE * 5, RATE * 2 + 1, RATE * 5 + 1]; + for max_len in max_lens { + let init_input = random_payload_without_len(max_len, usize::MAX); + let logic_input = random_payload_without_len(max_len, usize::MAX); + base_test().k(12).bench_builder(init_input, logic_input, |builder, range, payload| { + let ctx = builder.main(0); + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(Fr::from(payload.len as u64)); + hasher.hash_var_len_array(ctx, range, &inputs, len); + }); + } + } +} diff --git a/halo2-base/src/poseidon/hasher/tests/mod.rs b/halo2-base/src/poseidon/hasher/tests/mod.rs index 7deefefc..a734f7d0 100644 --- a/halo2-base/src/poseidon/hasher/tests/mod.rs +++ b/halo2-base/src/poseidon/hasher/tests/mod.rs @@ -1,12 +1,11 @@ use super::*; -use crate::{ - gates::{builder::GateThreadBuilder, GateChip}, - halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}, -}; +use crate::halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; use itertools::Itertools; mod compatibility; +mod hasher; +mod state; #[test] fn test_mds() { @@ -36,66 +35,5 @@ fn test_mds() { } } -#[test] -fn test_poseidon_against_test_vectors() { - let mut builder = GateThreadBuilder::prover(); - let gate = GateChip::::default(); - let ctx = builder.main(0); - - // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt - // poseidonperm_x5_254_3 - { - const R_F: usize = 8; - const R_P: usize = 57; - const T: usize = 3; - const RATE: usize = 2; - - let mut hasher = PoseidonHasher::::new::(ctx); - - let state = [0u64, 1, 2]; - hasher.state = - PoseidonState:: { s: state.map(|v| ctx.load_constant(Fr::from(v))) }; - let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); - hasher.permutation(ctx, &gate, inputs); // avoid padding - let state_0 = hasher.state.s; - let expected = [ - "7853200120776062878684798364095072458815029376092732009249414926327459813530", - "7142104613055408817911962100316808866448378443474503659992478482890339429929", - "6549537674122432311777789598043107870002137484850126429160507761192163713804", - ]; - for (word, expected) in state_0.into_iter().zip(expected.iter()) { - assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); - } - } - - // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt - // poseidonperm_x5_254_5 - { - const R_F: usize = 8; - const R_P: usize = 60; - const T: usize = 5; - const RATE: usize = 4; - - let mut hasher = PoseidonHasher::::new::(ctx); - - let state = [0u64, 1, 2, 3, 4]; - hasher.state = - PoseidonState:: { s: state.map(|v| ctx.load_constant(Fr::from(v))) }; - let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); - hasher.permutation(ctx, &gate, inputs); - let state_0 = hasher.state.s; - let expected = [ - "18821383157269793795438455681495246036402687001665670618754263018637548127333", - "7817711165059374331357136443537800893307845083525445872661165200086166013245", - "16733335996448830230979566039396561240864200624113062088822991822580465420551", - "6644334865470350789317807668685953492649391266180911382577082600917830417726", - "3372108894677221197912083238087960099443657816445944159266857514496320565191", - ]; - for (word, expected) in state_0.into_iter().zip(expected.iter()) { - assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); - } - } -} - // TODO: test clear()/squeeze(). // TODO: test constraints actually work. diff --git a/halo2-base/src/poseidon/hasher/tests/state.rs b/halo2-base/src/poseidon/hasher/tests/state.rs new file mode 100644 index 00000000..a6c40268 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/state.rs @@ -0,0 +1,129 @@ +use super::*; +use crate::{ + gates::{builder::GateThreadBuilder, GateChip}, + halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}, +}; + +#[test] +fn test_fix_permutation_against_test_vectors() { + let mut builder = GateThreadBuilder::prover(); + let gate = GateChip::::default(); + let ctx = builder.main(0); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + state.permutation(ctx, &gate, &inputs, None, &spec); // avoid padding + let state_0 = state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2, 3, 4].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + state.permutation(ctx, &gate, &inputs, None, &spec); + let state_0 = state.s; + let expected: [&str; 5] = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} + +#[test] +fn test_var_permutation_against_test_vectors() { + let mut builder = GateThreadBuilder::prover(); + let gate = GateChip::::default(); + let ctx = builder.main(0); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + let len = ctx.load_constant(Fr::from(RATE as u64)); + state.permutation(ctx, &gate, &inputs, Some(len), &spec); // avoid padding + let state_0 = state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2, 3, 4].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + let len = ctx.load_constant(Fr::from(RATE as u64)); + state.permutation(ctx, &gate, &inputs, Some(len), &spec); + let state_0 = state.s; + let expected: [&str; 5] = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} diff --git a/halo2-base/src/poseidon/mod.rs b/halo2-base/src/poseidon/mod.rs index 31628389..9e182c53 100644 --- a/halo2-base/src/poseidon/mod.rs +++ b/halo2-base/src/poseidon/mod.rs @@ -1,2 +1,114 @@ +use crate::{ + gates::RangeChip, + poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, + safe_types::{FixLenBytes, RangeInstructions, VarLenBytes, VarLenBytesVec}, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; + +use itertools::Itertools; + /// Module for Poseidon hasher pub mod hasher; + +/// Chip for Poseidon hash. +pub struct PoseidonChip<'a, F: ScalarField, const T: usize, const RATE: usize> { + range_chip: &'a RangeChip, + hasher: PoseidonHasher, +} + +impl<'a, F: ScalarField, const T: usize, const RATE: usize> PoseidonChip<'a, F, T, RATE> { + /// Create a new PoseidonChip. + pub fn new( + ctx: &mut Context, + spec: OptimizedPoseidonSpec, + range_chip: &'a RangeChip, + ) -> Self { + let mut hasher = PoseidonHasher::new(spec); + hasher.initialize_consts(ctx, range_chip.gate()); + Self { range_chip, hasher } + } +} + +/// Trait for Poseidon instructions +pub trait PoseidonInstructions { + /// Return hash of a [VarLenBytes] + fn hash_var_len_bytes( + &self, + ctx: &mut Context, + inputs: &VarLenBytes, + ) -> AssignedValue + where + F: BigPrimeField; + + /// Return hash of a [VarLenBytesVec] + fn hash_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: &VarLenBytesVec, + ) -> AssignedValue + where + F: BigPrimeField; + + /// Return hash of a [FixLenBytes] + fn hash_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: &FixLenBytes, + ) -> AssignedValue + where + F: BigPrimeField; +} + +impl<'a, F: ScalarField, const T: usize, const RATE: usize> PoseidonInstructions + for PoseidonChip<'a, F, T, RATE> +{ + fn hash_var_len_bytes( + &self, + ctx: &mut Context, + inputs: &VarLenBytes, + ) -> AssignedValue + where + F: BigPrimeField, + { + let inputs_len = inputs.len(); + self.hasher.hash_var_len_array( + ctx, + self.range_chip, + inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), + *inputs_len, + ) + } + + fn hash_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: &VarLenBytesVec, + ) -> AssignedValue + where + F: BigPrimeField, + { + let inputs_len = inputs.len(); + self.hasher.hash_var_len_array( + ctx, + self.range_chip, + &inputs.bytes().iter().map(|sb| *sb.as_ref()).collect_vec(), + *inputs_len, + ) + } + + fn hash_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: &FixLenBytes, + ) -> AssignedValue + where + F: BigPrimeField, + { + self.hasher.hash_fix_len_array( + ctx, + self.range_chip, + inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), + ) + } +}