Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Table domain separation #51

Merged
merged 3 commits into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 100 additions & 46 deletions plonk/src/circuit/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ use rayon::prelude::*;

/// The wire type identifier for range gates.
const RANGE_WIRE_ID: usize = 5;
/// The wire type identifier for the key index in a lookup gate
const LOOKUP_KEY_WIRE_ID: usize = 0;
/// The wire type identifiers for the searched pair values in a lookup gate
const LOOKUP_VAL_1_WIRE_ID: usize = 1;
const LOOKUP_VAL_2_WIRE_ID: usize = 2;
/// The wire type identifiers for the pair values in the lookup table
const TABLE_VAL_1_WIRE_ID: usize = 3;
const TABLE_VAL_2_WIRE_ID: usize = 4;

/// Hardcoded parameters for Plonk systems.
#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -314,29 +322,38 @@ impl<F: FftField> Circuit<F> for PlonkCircuit<F> {
}
// key-value map lookup gates
let mut key_val_table = HashSet::new();
key_val_table.insert((F::zero(), F::zero(), F::zero()));
let mut num_table_elems: u32 = 0;
key_val_table.insert((F::zero(), F::zero(), F::zero(), F::zero()));
let q_lookup_vec = self.q_lookup();
for (gate_id, &q_lookup) in q_lookup_vec.iter().enumerate() {
let q_dom_sep_vec = self.q_dom_sep();
let table_key_vec = self.table_key_vec();
let table_dom_sep_vec = self.table_dom_sep_vec();
// insert table elements
for (gate_id, ((&q_lookup, &table_dom_sep), &table_key)) in q_lookup_vec
.iter()
.zip(table_dom_sep_vec.iter())
.zip(table_key_vec.iter())
.enumerate()
{
if q_lookup != F::zero() {
let key = F::from(num_table_elems);
let val0 = self.witness(self.wire_variable(3, gate_id))?;
let val1 = self.witness(self.wire_variable(4, gate_id))?;
key_val_table.insert((key, val0, val1));
num_table_elems += 1;
let val0 = self.witness(self.wire_variable(TABLE_VAL_1_WIRE_ID, gate_id))?;
let val1 = self.witness(self.wire_variable(TABLE_VAL_2_WIRE_ID, gate_id))?;
key_val_table.insert((table_dom_sep, table_key, val0, val1));
}
}
for (gate_id, &q_lookup) in q_lookup_vec.iter().enumerate() {
// check lookups
for (gate_id, (&q_lookup, &q_dom_sep)) in
q_lookup_vec.iter().zip(q_dom_sep_vec.iter()).enumerate()
{
if q_lookup != F::zero() {
let key = self.witness(self.wire_variable(0, gate_id))?;
let val0 = self.witness(self.wire_variable(1, gate_id))?;
let val1 = self.witness(self.wire_variable(2, gate_id))?;
if !key_val_table.contains(&(key, val0, val1)) {
let key = self.witness(self.wire_variable(LOOKUP_KEY_WIRE_ID, gate_id))?;
let val0 = self.witness(self.wire_variable(LOOKUP_VAL_1_WIRE_ID, gate_id))?;
let val1 = self.witness(self.wire_variable(LOOKUP_VAL_2_WIRE_ID, gate_id))?;
if !key_val_table.contains(&(q_dom_sep, key, val0, val1)) {
return Err(GateCheckFailure(
gate_id,
format!(
"Lookup gate failed: ({}, {}, {}) not in the table",
key, val0, val1
"Lookup gate failed: ({}, {}, {}, {}) not in the table",
q_dom_sep, key, val0, val1
),
)
.into());
Expand Down Expand Up @@ -788,6 +805,21 @@ impl<F: FftField> PlonkCircuit<F> {
fn q_lookup(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_lookup()).collect()
}
// getter for all lookup domain separation selector
#[inline]
fn q_dom_sep(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_dom_sep()).collect()
}
// getter for the vector of table keys
#[inline]
fn table_key_vec(&self) -> Vec<F> {
self.gates.iter().map(|g| g.table_key()).collect()
}
// getter for the vector of table domain separation ids
#[inline]
fn table_dom_sep_vec(&self) -> Vec<F> {
self.gates.iter().map(|g| g.table_dom_sep()).collect()
}
// TODO: (alex) try return reference instead of expensive clone
// getter for all selectors in the following order:
// q_lc, q_mul, q_hash, q_o, q_c, q_ecc, [q_lookup (if support lookup)]
Expand Down Expand Up @@ -1171,24 +1203,45 @@ where
}

fn compute_key_table_polynomial(&self) -> Result<DensePolynomial<F>, PlonkError> {
let key_table = self.compute_key_table()?;
self.check_plonk_type(PlonkType::UltraPlonk)?;
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
Ok(DensePolynomial::from_coefficients_vec(
domain.ifft(&self.table_key_vec()),
))
}

fn compute_table_dom_sep_polynomial(&self) -> Result<DensePolynomial<F>, PlonkError> {
self.check_plonk_type(PlonkType::UltraPlonk)?;
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
Ok(DensePolynomial::from_coefficients_vec(
domain.ifft(&self.table_dom_sep_vec()),
zhenfeizhang marked this conversation as resolved.
Show resolved Hide resolved
))
}

fn compute_q_dom_sep_polynomial(&self) -> Result<DensePolynomial<F>, PlonkError> {
self.check_plonk_type(PlonkType::UltraPlonk)?;
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
Ok(DensePolynomial::from_coefficients_vec(
domain.ifft(&key_table),
domain.ifft(&self.q_dom_sep()),
))
}

fn compute_merged_lookup_table(&self, tau: F) -> Result<Vec<F>, PlonkError> {
let range_table = self.compute_range_table()?;
let key_table = self.compute_key_table()?;
let table_key_vec = self.table_key_vec();
let table_dom_sep_vec = self.table_dom_sep_vec();
let q_lookup_vec = self.q_lookup();

let mut merged_lookup_table = vec![];
for i in 0..self.eval_domain_size()? {
merged_lookup_table.push(self.merged_table_value(
tau,
&range_table,
&key_table,
&table_key_vec,
&table_dom_sep_vec,
&q_lookup_vec,
i,
)?);
Expand Down Expand Up @@ -1230,9 +1283,11 @@ where
let beta_plus_one = F::one() + *beta;
let gamma_mul_beta_plus_one = *gamma * beta_plus_one;
let q_lookup_vec = self.q_lookup();
let q_dom_sep_vec = self.q_dom_sep();
for j in 0..(n - 2) {
// compute merged lookup witness value
let lookup_wire_val = self.merged_lookup_wire_value(*tau, j, &q_lookup_vec)?;
let lookup_wire_val =
self.merged_lookup_wire_value(*tau, j, &q_lookup_vec, &q_dom_sep_vec)?;
let table_val = merged_lookup_table[j];
let table_next_val = merged_lookup_table[j + 1];
let h1_val = sorted_vec[j];
Expand Down Expand Up @@ -1281,8 +1336,9 @@ where
// only the first n-1 variables are for lookup
let mut lookup_map = HashMap::<F, usize>::new();
let q_lookup_vec = self.q_lookup();
let q_dom_sep_vec = self.q_dom_sep();
for i in 0..(n - 1) {
let elem = self.merged_lookup_wire_value(tau, i, &q_lookup_vec)?;
let elem = self.merged_lookup_wire_value(tau, i, &q_lookup_vec, &q_dom_sep_vec)?;
let n_lookups = lookup_map.entry(elem).or_insert(0);
*n_lookups += 1;
}
Expand Down Expand Up @@ -1328,35 +1384,26 @@ impl<F: PrimeField> PlonkCircuit<F> {
Ok(range_table)
}

#[inline]
// TODO: generalize to arbitrary key sets.
fn compute_key_table(&self) -> Result<Vec<F>, PlonkError> {
self.check_plonk_type(PlonkType::UltraPlonk)?;
self.check_finalize_flag(true)?;
let n = self.eval_domain_size()?;
let mut key_table = vec![F::zero(); n - 1 - self.num_table_elems];
for i in 0..self.num_table_elems {
key_table.push(F::from(i as u32));
}
key_table.push(F::zero());
Ok(key_table)
}

#[inline]
fn merged_table_value(
&self,
tau: F,
range_table: &[F],
key_table: &[F],
table_key_vec: &[F],
table_dom_sep_vec: &[F],
q_lookup_vec: &[F],
i: usize,
) -> Result<F, PlonkError> {
let range_val = range_table[i];
let key_val = key_table[i];
let key_val = table_key_vec[i];
let dom_sep_val = table_dom_sep_vec[i];
let q_lookup_val = q_lookup_vec[i];
let w3_val = self.witness(self.wire_variable(3, i))?;
let w4_val = self.witness(self.wire_variable(4, i))?;
Ok(range_val + q_lookup_val * tau * (key_val + tau * (w3_val + tau * w4_val)))
let table_val_1 = self.witness(self.wire_variable(TABLE_VAL_1_WIRE_ID, i))?;
let table_val_2 = self.witness(self.wire_variable(TABLE_VAL_2_WIRE_ID, i))?;
Ok(range_val
+ q_lookup_val
* tau
* (dom_sep_val + tau * (key_val + tau * (table_val_1 + tau * table_val_2))))
}

#[inline]
Expand All @@ -1365,13 +1412,18 @@ impl<F: PrimeField> PlonkCircuit<F> {
tau: F,
i: usize,
q_lookup_vec: &[F],
q_dom_sep_vec: &[F],
) -> Result<F, PlonkError> {
let w_range_val = self.witness(self.wire_variable(RANGE_WIRE_ID, i))?;
let w_0_val = self.witness(self.wire_variable(0, i))?;
let w_1_val = self.witness(self.wire_variable(1, i))?;
let w_2_val = self.witness(self.wire_variable(2, i))?;
let lookup_key = self.witness(self.wire_variable(LOOKUP_KEY_WIRE_ID, i))?;
let lookup_val_1 = self.witness(self.wire_variable(LOOKUP_VAL_1_WIRE_ID, i))?;
let lookup_val_2 = self.witness(self.wire_variable(LOOKUP_VAL_2_WIRE_ID, i))?;
let q_lookup_val = q_lookup_vec[i];
Ok(w_range_val + q_lookup_val * tau * (w_0_val + tau * (w_1_val + tau * w_2_val)))
let q_dom_sep_val = q_dom_sep_vec[i];
Ok(w_range_val
+ q_lookup_val
* tau
* (q_dom_sep_val + tau * (lookup_key + tau * (lookup_val_1 + tau * lookup_val_2))))
}
}

Expand Down Expand Up @@ -1969,7 +2021,7 @@ pub(crate) mod test {

// Check key table polynomial
let key_table_poly = circuit.compute_key_table_polynomial()?;
let key_table = circuit.compute_key_table()?;
let key_table = circuit.table_key_vec();
check_polynomial(&key_table_poly, &key_table);

// Check sorted vector polynomials
Expand Down Expand Up @@ -2016,8 +2068,10 @@ pub(crate) mod test {
let one_plus_beta = F::one() + beta;
let gamma_mul_one_plus_beta = gamma * one_plus_beta;
let q_lookup_vec = circuit.q_lookup();
let q_dom_sep = circuit.q_dom_sep();
for j in 0..(n - 2) {
let lookup_wire_val = circuit.merged_lookup_wire_value(tau, j, &q_lookup_vec)?;
let lookup_wire_val =
circuit.merged_lookup_wire_value(tau, j, &q_lookup_vec, &q_dom_sep)?;
let table_val = merged_lookup_table[j];
let table_next_val = merged_lookup_table[j + 1];
let h1_val = sorted_vec[j];
Expand Down
3 changes: 0 additions & 3 deletions plonk/src/circuit/customized/ecc/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,6 @@ where

// create circuit
let range_size = F::from((1 << c) as u32);
for &var in decomposed_scalar_vars.iter() {
circuit.range_gate(var, c)?;
}
circuit.decompose_vars_gate(decomposed_scalar_vars.clone(), scalar_var, range_size)?;

Ok(decomposed_scalar_vars)
Expand Down
17 changes: 15 additions & 2 deletions plonk/src/circuit/customized/gates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,13 @@ where

/// An UltraPlonk lookup gate
#[derive(Debug, Clone)]
pub struct LookupGate;
pub struct LookupGate<F: Field> {
pub(crate) q_dom_sep: F,
pub(crate) table_dom_sep: F,
pub(crate) table_key: F,
}

impl<F> Gate<F> for LookupGate
impl<F> Gate<F> for LookupGate<F>
where
F: Field,
{
Expand All @@ -317,4 +321,13 @@ where
fn q_lookup(&self) -> F {
F::one()
}
fn q_dom_sep(&self) -> F {
self.q_dom_sep
}
fn table_key(&self) -> F {
self.table_key
}
fn table_dom_sep(&self) -> F {
self.table_dom_sep
}
}
73 changes: 51 additions & 22 deletions plonk/src/circuit/customized/ultraplonk/lookup_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,14 @@ use crate::{
errors::PlonkError,
};
use ark_ff::PrimeField;
use ark_std::{boxed::Box, cmp::max, vec::Vec};
use ark_std::{boxed::Box, cmp::max};

impl<F: PrimeField> PlonkCircuit<F> {
/// Create a table with keys/values
/// [table_id, ..., table_id + n - 1] and
/// [table_vars\[0\], ..., table_vars[n - 1]];
/// [0, ..., n - 1] and
/// [table_vars\[0\], ..., table_vars\[n - 1\]];
/// and create a list of variable tuples to be looked up:
/// [lookup_vars\[0\], ..., lookup_vars[m - 1]];
///
/// **For each variable tuple `(lookup_var.0, lookup_var.1, lookup_var.2)`
/// to be looked up, the index variable `lookup_var.0` is required to be
/// in range [0, n) (either constrained by a range-check gate or other
/// circuits), so that one can't set it out of bounds and thus do a
/// lookup into one of the *other* tables. **
/// [lookup_vars\[0\], ..., lookup_vars\[m - 1\]];
///
/// w.l.o.g we assume n = m as we can pad with dummy tuples when n != m
pub fn create_table_and_lookup_variables(
Expand All @@ -42,24 +36,38 @@ impl<F: PrimeField> PlonkCircuit<F> {
self.check_var_bound(table_var.1)?;
}
let n = max(lookup_vars.len(), table_vars.len());
// update lookup keys for domain separation.
let lookup_keys: Vec<Variable> = lookup_vars
.iter()
.map(|&(key, ..)| self.add_constant(key, &F::from(self.num_table_elems() as u32)))
.collect::<Result<Vec<_>, _>>()?;
let n_gate = self.num_gates();
(*self.table_gate_ids_mut()).push((n_gate, n));
let table_ctr = F::from(self.table_gate_ids_mut().len() as u64);
for i in 0..n {
let (key, val0, val1) = match i < lookup_vars.len() {
true => (lookup_keys[i], lookup_vars[i].1, lookup_vars[i].2),
false => (self.zero(), self.zero(), self.zero()),
let (q_dom_sep, key, val0, val1) = match i < lookup_vars.len() {
true => (
table_ctr,
lookup_vars[i].0,
lookup_vars[i].1,
lookup_vars[i].2,
),
false => (F::zero(), self.zero(), self.zero(), self.zero()),
};
let (table_val0, table_val1) = match i < table_vars.len() {
true => table_vars[i],
false => (self.zero(), self.zero()),
let (table_dom_sep, table_key, table_val0, table_val1) = match i < table_vars.len() {
true => (
table_ctr,
F::from(i as u64),
table_vars[i].0,
table_vars[i].1,
),
false => (F::zero(), F::zero(), self.zero(), self.zero()),
};
let wire_vars = [key, val0, val1, table_val0, table_val1];
self.insert_gate(&wire_vars, Box::new(LookupGate))?;

self.insert_gate(
&wire_vars,
Box::new(LookupGate {
q_dom_sep,
table_dom_sep,
table_key,
}),
)?;
alxiong marked this conversation as resolved.
Show resolved Hide resolved
}
*self.num_table_elems_mut() += n;
Ok(())
Expand Down Expand Up @@ -136,6 +144,27 @@ mod test {
.create_table_and_lookup_variables(&lookup_vars, &bad_table_vars)
.is_err());

// A lookup over a separate table should not satisfy the circuit.
let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(4);
let mut rng = test_rng();

let val0 = circuit.create_variable(F::rand(&mut rng))?;
let val1 = circuit.create_variable(F::rand(&mut rng))?;
let table_vars_1 = vec![(val0, val1)];
let val2 = circuit.create_variable(F::rand(&mut rng))?;
let val3 = circuit.create_variable(F::rand(&mut rng))?;
let table_vars_2 = vec![(val2, val3)];
let val2 = circuit.witness(table_vars_2[0].0)?;
let val3 = circuit.witness(table_vars_2[0].1)?;
let val2_var = circuit.create_variable(val2)?;
let val3_var = circuit.create_variable(val3)?;
let lookup_vars_1 = vec![(circuit.zero(), val2_var, val3_var)];

circuit.create_table_and_lookup_variables(&lookup_vars_1, &table_vars_2)?;
assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
circuit.create_table_and_lookup_variables(&lookup_vars_1, &table_vars_1)?;
assert!(circuit.check_circuit_satisfiability(&[]).is_err());

Ok(())
}
}
Loading