Skip to content

Commit

Permalink
[feat] Add Poseidon Chip (#114)
Browse files Browse the repository at this point in the history
* Add Poseidon hasher

* Fix test/lint

* Fix nits

* Fix lint

* Fix nits & add comments

* Add prover test

* Fix CI
  • Loading branch information
nyunyunyunyu authored Aug 22, 2023
1 parent 83ca65e commit 7b23747
Show file tree
Hide file tree
Showing 10 changed files with 699 additions and 140 deletions.
2 changes: 1 addition & 1 deletion halo2-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ harness = false

[[example]]
name = "inner_product"
features = ["test-utils"]
required-features = ["test-utils"]
16 changes: 16 additions & 0 deletions halo2-base/src/gates/flex_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ pub trait GateInstructions<F: ScalarField> {
ctx.assign_region_last([a, b, Constant(F::ONE), Witness(out_val)], [0])
}

/// Constrains and returns `out = a + 1`.
///
/// * `ctx`: [Context] to add the constraints to
/// * `a`: [QuantumCell] value
fn inc(&self, ctx: &mut Context<F>, a: impl Into<QuantumCell<F>>) -> AssignedValue<F> {
self.add(ctx, a, Constant(F::ONE))
}

/// Constrains and returns `a + b * (-1) = out`.
///
/// Defines a vertical gate of form | a - b | b | 1 | a |, where (a - b) = out.
Expand All @@ -200,6 +208,14 @@ pub trait GateInstructions<F: ScalarField> {
ctx.get(-4)
}

/// Constrains and returns `out = a - 1`.
///
/// * `ctx`: [Context] to add the constraints to
/// * `a`: [QuantumCell] value
fn dec(&self, ctx: &mut Context<F>, a: impl Into<QuantumCell<F>>) -> AssignedValue<F> {
self.sub(ctx, a, Constant(F::ONE))
}

/// Constrains and returns `a - b * c = out`.
///
/// Defines a vertical gate of form | a - b * c | b | c | a |, where (a - b * c) = out.
Expand Down
12 changes: 12 additions & 0 deletions halo2-base/src/gates/tests/flex_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,24 @@ pub fn test_add(inputs: &[QuantumCell<Fr>]) -> Fr {
base_test().run_gate(|ctx, chip| *chip.add(ctx, inputs[0], inputs[1]).value())
}

#[test_case(Witness(Fr::from(10))=> Fr::from(11); "inc(): 10 -> 11")]
#[test_case(Witness(Fr::from(1))=> Fr::from(2); "inc(): 1 -> 2")]
pub fn test_inc(input: QuantumCell<Fr>) -> Fr {
base_test().run_gate(|ctx, chip| *chip.inc(ctx, input).value())
}

#[test_case(&[10, 12].map(Fr::from).map(Witness)=> -Fr::from(2) ; "sub(): 10 - 12 == -2")]
#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(0) ; "sub(): 1 - 1 == 0")]
pub fn test_sub(inputs: &[QuantumCell<Fr>]) -> Fr {
base_test().run_gate(|ctx, chip| *chip.sub(ctx, inputs[0], inputs[1]).value())
}

#[test_case(Witness(Fr::from(10))=> Fr::from(9); "dec(): 10 -> 9")]
#[test_case(Witness(Fr::from(1))=> Fr::from(0); "dec(): 1 -> 0")]
pub fn test_dec(input: QuantumCell<Fr>) -> Fr {
base_test().run_gate(|ctx, chip| *chip.dec(ctx, input).value())
}

#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub_mul(): 1 - 1 * 1 == 0")]
pub fn test_sub_mul(inputs: &[QuantumCell<Fr>]) -> Fr {
base_test().run_gate(|ctx, chip| *chip.sub_mul(ctx, inputs[0], inputs[1], inputs[2]).value())
Expand Down
206 changes: 156 additions & 50 deletions halo2-base/src/poseidon/hasher/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use std::mem;

use crate::{
gates::GateInstructions,
poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState},
AssignedValue, Context, ScalarField,
safe_types::{RangeInstructions, SafeTypeChip},
utils::BigPrimeField,
AssignedValue, Context,
QuantumCell::Constant,
ScalarField,
};

use getset::Getters;
use num_bigint::BigUint;
use std::{cell::OnceCell, mem};

#[cfg(test)]
mod tests;

Expand All @@ -16,15 +22,142 @@ pub mod spec;
/// Module for poseidon states.
pub mod state;

/// Poseidon hasher. This is stateful.
/// Stateless Poseidon hasher.
pub struct PoseidonHasher<F: ScalarField, const T: usize, const RATE: usize> {
spec: OptimizedPoseidonSpec<F, T, RATE>,
consts: OnceCell<PoseidonHasherConsts<F, T, RATE>>,
}
#[derive(Getters)]
struct PoseidonHasherConsts<F: ScalarField, const T: usize, const RATE: usize> {
#[getset(get = "pub")]
init_state: PoseidonState<F, T, RATE>,
// hash of an empty input("").
#[getset(get = "pub")]
empty_hash: AssignedValue<F>,
}

impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherConsts<F, T, RATE> {
pub fn new(
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
spec: &OptimizedPoseidonSpec<F, T, RATE>,
) -> Self {
let init_state = PoseidonState::default(ctx);
let mut state = init_state.clone();
let empty_hash = fix_len_array_squeeze(ctx, gate, &[], &mut state, spec);
Self { init_state, empty_hash }
}
}

impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RATE> {
/// Create a poseidon hasher from an existing spec.
pub fn new(spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
Self { spec, consts: OnceCell::new() }
}
/// Initialize necessary consts of hasher. Must be called before any computation.
pub fn initialize_consts(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
self.consts.get_or_init(|| PoseidonHasherConsts::<F, T, RATE>::new(ctx, gate, &self.spec));
}

fn empty_hash(&self) -> &AssignedValue<F> {
self.consts.get().unwrap().empty_hash()
}
fn init_state(&self) -> &PoseidonState<F, T, RATE> {
self.consts.get().unwrap().init_state()
}

/// Constrains and returns hash of a witness array with a variable length.
///
/// Assumes `len` is within [usize] and `len <= inputs.len()`.
/// * inputs: An right-padded array of [AssignedValue]. Constraints on paddings are not required.
/// * len: Length of `inputs`.
/// Return hash of `inputs`.
pub fn hash_var_len_array(
&self,
ctx: &mut Context<F>,
range: &impl RangeInstructions<F>,
inputs: &[AssignedValue<F>],
len: AssignedValue<F>,
) -> AssignedValue<F>
where
F: BigPrimeField,
{
let max_len = inputs.len();
if max_len == 0 {
return *self.empty_hash();
};

// len <= max_len --> num_of_bits(len) <= num_of_bits(max_len)
let num_bits = (usize::BITS - max_len.leading_zeros()) as usize;
// num_perm = len // RATE + 1, len_last_chunk = len % RATE
let (mut num_perm, len_last_chunk) = range.div_mod(ctx, len, BigUint::from(RATE), num_bits);
num_perm = range.gate().inc(ctx, num_perm);

let mut state = self.init_state().clone();
let mut result_state = state.clone();
for (i, chunk) in inputs.chunks(RATE).enumerate() {
let is_last_perm =
range.gate().is_equal(ctx, num_perm, Constant(F::from((i + 1) as u64)));
let len_chunk = range.gate().select(
ctx,
len_last_chunk,
Constant(F::from(RATE as u64)),
is_last_perm,
);

state.permutation(ctx, range.gate(), chunk, Some(len_chunk), &self.spec);
result_state.select(
ctx,
range.gate(),
SafeTypeChip::<F>::unsafe_to_bool(is_last_perm),
&state,
);
}
if max_len % RATE == 0 {
let is_last_perm = range.gate().is_equal(
ctx,
num_perm,
Constant(F::from((max_len / RATE + 1) as u64)),
);
let len_chunk = ctx.load_zero();
state.permutation(ctx, range.gate(), &[], Some(len_chunk), &self.spec);
result_state.select(
ctx,
range.gate(),
SafeTypeChip::<F>::unsafe_to_bool(is_last_perm),
&state,
);
}
result_state.s[1]
}

/// Constrains and returns hash of a witness array.
///
/// * inputs: An array of [AssignedValue].
/// Return hash of `inputs`.
pub fn hash_fix_len_array(
&self,
ctx: &mut Context<F>,
range: &impl RangeInstructions<F>,
inputs: &[AssignedValue<F>],
) -> AssignedValue<F>
where
F: BigPrimeField,
{
let mut state = self.init_state().clone();
fix_len_array_squeeze(ctx, range.gate(), inputs, &mut state, &self.spec)
}
}

/// Poseidon sponge. This is stateful.
pub struct PoseidonSponge<F: ScalarField, const T: usize, const RATE: usize> {
init_state: PoseidonState<F, T, RATE>,
state: PoseidonState<F, T, RATE>,
spec: OptimizedPoseidonSpec<F, T, RATE>,
absorbing: Vec<AssignedValue<F>>,
}

impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RATE> {
impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonSponge<F, T, RATE> {
/// Create new Poseidon hasher.
pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>(
ctx: &mut Context<F>,
Expand Down Expand Up @@ -64,53 +197,26 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RAT
gate: &impl GateInstructions<F>,
) -> AssignedValue<F> {
let input_elements = mem::take(&mut self.absorbing);
let exact = input_elements.len() % RATE == 0;

for chunk in input_elements.chunks(RATE) {
self.permutation(ctx, gate, chunk.to_vec());
}
if exact {
self.permutation(ctx, gate, vec![]);
}

self.state.s[1]
fix_len_array_squeeze(ctx, gate, &input_elements, &mut self.state, &self.spec)
}
}

fn permutation(
&mut self,
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
inputs: Vec<AssignedValue<F>>,
) {
let r_f = self.spec.r_f / 2;
let mds = &self.spec.mds_matrices.mds.0;
let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0;
let sparse_matrices = &self.spec.mds_matrices.sparse_matrices;

// First half of the full round
let constants = &self.spec.constants.start;
self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]);
for constants in constants.iter().skip(1).take(r_f - 1) {
self.state.sbox_full(ctx, gate, constants);
self.state.apply_mds(ctx, gate, mds);
}
self.state.sbox_full(ctx, gate, constants.last().unwrap());
self.state.apply_mds(ctx, gate, pre_sparse_mds);

// Partial rounds
let constants = &self.spec.constants.partial;
for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
self.state.sbox_part(ctx, gate, constant);
self.state.apply_sparse_mds(ctx, gate, sparse_mds);
}
/// ATTETION: input_elements.len() needs to be fixed at compile time.
fn fix_len_array_squeeze<F: ScalarField, const T: usize, const RATE: usize>(
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
input_elements: &[AssignedValue<F>],
state: &mut PoseidonState<F, T, RATE>,
spec: &OptimizedPoseidonSpec<F, T, RATE>,
) -> AssignedValue<F> {
let exact = input_elements.len() % RATE == 0;

// Second half of the full rounds
let constants = &self.spec.constants.end;
for constants in constants.iter() {
self.state.sbox_full(ctx, gate, constants);
self.state.apply_mds(ctx, gate, mds);
}
self.state.sbox_full(ctx, gate, &[F::ZERO; T]);
self.state.apply_mds(ctx, gate, mds);
for chunk in input_elements.chunks(RATE) {
state.permutation(ctx, gate, chunk, None, spec);
}
if exact {
state.permutation(ctx, gate, &[], None, spec);
}

state.s[1]
}
Loading

0 comments on commit 7b23747

Please sign in to comment.