Skip to content

Commit

Permalink
refactor: generalize constraint circuit methods
Browse files Browse the repository at this point in the history
Concretely,
- make `b_constant` accept any `B: Into<BFieldElement>`, and
- make `x_constant` accept any `X: Into<XFieldElement>`.
  • Loading branch information
jan-ferdinand committed Mar 3, 2024
1 parent 8aa04b6 commit aecd75f
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 186 deletions.
6 changes: 2 additions & 4 deletions constraint-evaluation-generator/src/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,8 @@ impl Constraints {
#[cfg(test)]
pub(crate) mod tests {
use twenty_first::bfe;
use twenty_first::xfe;

use triton_vm::prelude::BFieldElement;
use triton_vm::prelude::XFieldElement;
use triton_vm::table::challenges::ChallengeId;
use triton_vm::table::constraint_circuit::DualRowIndicator;

Expand All @@ -232,7 +230,7 @@ pub(crate) mod tests {
pub(crate) fn mini_constraints() -> Self {
let circuit_builder = ConstraintCircuitBuilder::new();
let challenge = |c| circuit_builder.challenge(c);
let constant = |c: u32| circuit_builder.b_constant(bfe!(c));
let constant = |c: u32| circuit_builder.b_constant(c);
let base_row = |i| circuit_builder.input(SingleRowIndicator::BaseRow(i));
let ext_row = |i| circuit_builder.input(SingleRowIndicator::ExtRow(i));

Expand Down Expand Up @@ -275,7 +273,7 @@ pub(crate) mod tests {
fn small_transition_constraints() -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
let circuit_builder = ConstraintCircuitBuilder::new();
let challenge = |c| circuit_builder.challenge(c);
let constant = |c: u32| circuit_builder.x_constant(xfe!(c));
let constant = |c: u32| circuit_builder.x_constant(c);

let curr_base_row = |col| circuit_builder.input(DualRowIndicator::CurrentBaseRow(col));
let next_base_row = |col| circuit_builder.input(DualRowIndicator::NextBaseRow(col));
Expand Down
16 changes: 8 additions & 8 deletions triton-vm/src/table/cascade_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ impl ExtCascadeTable {
};
let challenge = |challenge_id: ChallengeId| circuit_builder.challenge(challenge_id);

let one = circuit_builder.b_constant(b_field_element::BFIELD_ONE);
let two = circuit_builder.b_constant(bfe!(2));
let two_pow_8 = circuit_builder.b_constant(bfe!(1 << 8));
let one = || circuit_builder.b_constant(1);
let two = || circuit_builder.b_constant(2);
let two_pow_8 = circuit_builder.b_constant(1 << 8);
let lookup_arg_default_initial = circuit_builder.x_constant(LookupArg::default_initial());

let is_padding = base_row(IsPadding);
Expand Down Expand Up @@ -170,7 +170,7 @@ impl ExtCascadeTable {
- lookup_arg_default_initial.clone())
* (hash_indeterminate - compressed_row_hash)
- lookup_multiplicity;
let hash_table_log_derivative_is_initialized_correctly = (one.clone() - is_padding.clone())
let hash_table_log_derivative_is_initialized_correctly = (one() - is_padding.clone())
* hash_table_log_derivative_has_accumulated_first_row
+ is_padding.clone() * hash_table_log_derivative_is_default_initial;

Expand All @@ -185,10 +185,10 @@ impl ExtCascadeTable {
(lookup_table_client_log_derivative - lookup_arg_default_initial)
* (lookup_indeterminate.clone() - compressed_row_lo.clone())
* (lookup_indeterminate.clone() - compressed_row_hi.clone())
- two * lookup_indeterminate
- two() * lookup_indeterminate
+ compressed_row_lo
+ compressed_row_hi;
let lookup_table_log_derivative_is_initialized_correctly = (one - is_padding.clone())
let lookup_table_log_derivative_is_initialized_correctly = (one() - is_padding.clone())
* lookup_table_log_derivative_has_accumulated_first_row
+ is_padding * lookup_table_log_derivative_is_default_initial;

Expand All @@ -205,7 +205,7 @@ impl ExtCascadeTable {
circuit_builder.input(BaseRow(col_id.master_base_table_index()))
};

let one = circuit_builder.b_constant(b_field_element::BFIELD_ONE);
let one = circuit_builder.b_constant(1);
let is_padding = base_row(IsPadding);
let is_padding_is_0_or_1 = is_padding.clone() * (one - is_padding);

Expand All @@ -216,7 +216,7 @@ impl ExtCascadeTable {
circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
let challenge = |c| circuit_builder.challenge(c);
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);

let current_base_row = |column_idx: CascadeBaseTableColumn| {
circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index()))
Expand Down
52 changes: 29 additions & 23 deletions triton-vm/src/table/constraint_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -954,27 +954,33 @@ impl<II: InputIndicator> ConstraintCircuitBuilder<II> {

/// The unique monad representing the constant value 0.
pub fn zero(&self) -> ConstraintCircuitMonad<II> {
self.b_constant(bfe!(0))
self.b_constant(0)
}

/// The unique monad representing the constant value 1.
pub fn one(&self) -> ConstraintCircuitMonad<II> {
self.b_constant(bfe!(1))
self.b_constant(1)
}

/// The unique monad representing the constant value -1.
pub fn minus_one(&self) -> ConstraintCircuitMonad<II> {
self.b_constant(bfe!(-1))
self.b_constant(-1)
}

/// Create constant leaf node.
pub fn x_constant(&self, xfe: XFieldElement) -> ConstraintCircuitMonad<II> {
self.make_leaf(XConstant(xfe))
/// Leaf node with constant over the [base field][BFieldElement].
pub fn b_constant<B>(&self, bfe: B) -> ConstraintCircuitMonad<II>
where
B: Into<BFieldElement>,
{
self.make_leaf(BConstant(bfe.into()))
}

/// Create constant leaf node.
pub fn b_constant(&self, bfe: BFieldElement) -> ConstraintCircuitMonad<II> {
self.make_leaf(BConstant(bfe))
/// Leaf node with constant over the [extension field][XFieldElement].
pub fn x_constant<X>(&self, xfe: X) -> ConstraintCircuitMonad<II>
where
X: Into<XFieldElement>,
{
self.make_leaf(XConstant(xfe.into()))
}

/// Create deterministic input leaf node.
Expand Down Expand Up @@ -1096,13 +1102,13 @@ mod tests {
5..=9 => circuit_builder.input(DualRowIndicator::NextBaseRow(base_col_index)),
10..=14 => circuit_builder.input(DualRowIndicator::CurrentExtRow(ext_col_index)),
15..=19 => circuit_builder.input(DualRowIndicator::NextExtRow(ext_col_index)),
20..=24 => circuit_builder.b_constant(rng.gen()),
25..=29 => circuit_builder.x_constant(rng.gen()),
20..=24 => circuit_builder.b_constant(rng.gen::<BFieldElement>()),
25..=29 => circuit_builder.x_constant(rng.gen::<XFieldElement>()),
30..=34 => circuit_builder.challenge(random_challenge_id()),
35 => circuit_builder.b_constant(bfe!(0)),
36 => circuit_builder.x_constant(xfe!(0)),
37 => circuit_builder.b_constant(bfe!(1)),
38 => circuit_builder.x_constant(xfe!(1)),
35 => circuit_builder.b_constant(0),
36 => circuit_builder.x_constant(0),
37 => circuit_builder.b_constant(1),
38 => circuit_builder.x_constant(1),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -1147,7 +1153,7 @@ mod tests {
let hash0 = hasher0.finish();
assert_eq!(circuit, circuit);

let zero = circuit.builder.x_constant(0.into());
let zero = circuit.builder.x_constant(0);
let same_circuit = circuit.clone() + zero;
let mut hasher1 = DefaultHasher::new();
same_circuit.hash(&mut hasher1);
Expand Down Expand Up @@ -1204,9 +1210,9 @@ mod tests {
ConstraintCircuitBuilder::new();
let var_0 = circuit_builder.input(DualRowIndicator::CurrentBaseRow(0));
let var_4 = circuit_builder.input(DualRowIndicator::NextBaseRow(4));
let four = circuit_builder.x_constant(4.into());
let one = circuit_builder.x_constant(1.into());
let zero = circuit_builder.x_constant(0.into());
let four = circuit_builder.x_constant(4);
let one = circuit_builder.x_constant(1);
let zero = circuit_builder.x_constant(0);

assert_ne!(var_0, var_4);
assert_ne!(var_0, four);
Expand Down Expand Up @@ -1267,8 +1273,8 @@ mod tests {
fn constant_folding_pbt() {
for _ in 0..200 {
let circuit = random_circuit();
let one = circuit.builder.x_constant(1.into());
let zero = circuit.builder.x_constant(0.into());
let one = circuit.builder.x_constant(1);
let zero = circuit.builder.x_constant(0);

// Verify that constant folding can handle a = a * 1
let copy_0 = deep_copy(&circuit.circuit.borrow());
Expand Down Expand Up @@ -1580,7 +1586,7 @@ mod tests {
let builder = ConstraintCircuitBuilder::new();
let x = |i| builder.input(BaseRow(i));
let y = |i| builder.input(ExtRow(i));
let b_con = |i: u64| builder.b_constant(i.into());
let b_con = |i: u64| builder.b_constant(i);

let constraint_0 = x(0) * x(0) * (x(1) - x(2)) - x(0) * x(2) - b_con(42);
let constraint_1 = x(1) * (x(1) - b_con(5)) * x(2) * (x(2) - b_con(1));
Expand Down Expand Up @@ -2134,7 +2140,7 @@ mod tests {
let builder = ConstraintCircuitBuilder::new();

let x = |i| builder.input(BaseRow(i));
let b_con = |i: u64| builder.b_constant(i.into());
let b_con = |i: u64| builder.b_constant(i);

let sub_tree_0 = x(0) * x(1) * (x(2) - b_con(1)) * x(3) * x(4);
let sub_tree_1 = x(0) * x(1) * (x(2) - b_con(1)) * x(3) * x(5);
Expand Down
22 changes: 11 additions & 11 deletions triton-vm/src/table/hash_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl ExtHashTable {
mid_low: ConstraintCircuitMonad<II>,
lowest: ConstraintCircuitMonad<II>,
) -> ConstraintCircuitMonad<II> {
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);
let montgomery_modulus_inv = circuit_builder.b_constant(MONTGOMERY_MODULUS.inverse());

let sum_of_shifted_limbs = highest * constant(1 << 48)
Expand All @@ -161,7 +161,7 @@ impl ExtHashTable {
circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
let challenge = |c| circuit_builder.challenge(c);
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);

let base_row = |column: HashBaseTableColumn| {
circuit_builder.input(BaseRow(column.master_base_table_index()))
Expand Down Expand Up @@ -338,7 +338,7 @@ impl ExtHashTable {
round_number_to_deselect <= NUM_ROUNDS,
"Round number must be in [0, {NUM_ROUNDS}] but got {round_number_to_deselect}."
);
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);

// To not subtract zero from the first factor: some special casing.
let first_factor = match round_number_to_deselect {
Expand All @@ -358,7 +358,7 @@ impl ExtHashTable {
mode_circuit_node: &ConstraintCircuitMonad<II>,
mode_to_select: HashTableMode,
) -> ConstraintCircuitMonad<II> {
mode_circuit_node.clone() - circuit_builder.b_constant(mode_to_select.into())
mode_circuit_node.clone() - circuit_builder.b_constant(mode_to_select)
}

/// A constraint circuit evaluating to zero if and only if the given `mode_circuit_node` is
Expand All @@ -368,7 +368,7 @@ impl ExtHashTable {
mode_circuit_node: &ConstraintCircuitMonad<II>,
mode_to_deselect: HashTableMode,
) -> ConstraintCircuitMonad<II> {
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);
HashTableMode::iter()
.filter(|&mode| mode != mode_to_deselect)
.map(|mode| mode_circuit_node.clone() - constant(mode.into()))
Expand All @@ -380,7 +380,7 @@ impl ExtHashTable {
current_instruction_node: &ConstraintCircuitMonad<II>,
instruction_to_deselect: Instruction,
) -> ConstraintCircuitMonad<II> {
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);
let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
let relevant_instructions = [Hash, SpongeInit, SpongeAbsorb, SpongeSqueeze];
assert!(relevant_instructions.contains(&instruction_to_deselect));
Expand All @@ -396,7 +396,7 @@ impl ExtHashTable {
circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);
let base_row = |column_id: HashBaseTableColumn| {
circuit_builder.input(BaseRow(column_id.master_base_table_index()))
};
Expand Down Expand Up @@ -617,7 +617,7 @@ impl ExtHashTable {
) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
let challenge = |c| circuit_builder.challenge(c);
let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);

let opcode_hash = opcode(Hash);
let opcode_sponge_init = opcode(SpongeInit);
Expand Down Expand Up @@ -1070,7 +1070,7 @@ impl ExtHashTable {
[ConstraintCircuitMonad<DualRowIndicator>; STATE_SIZE],
[ConstraintCircuitMonad<DualRowIndicator>; STATE_SIZE],
) {
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);
let b_constant = |c| circuit_builder.b_constant(c);
let current_base_row = |column_idx: HashBaseTableColumn| {
circuit_builder.input(CurrentBaseRow(column_idx.master_base_table_index()))
Expand Down Expand Up @@ -1221,7 +1221,7 @@ impl ExtHashTable {
) -> ConstraintCircuitMonad<DualRowIndicator> {
let challenge = |c| circuit_builder.challenge(c);
let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
let constant = |c: u32| circuit_builder.b_constant(c.into());
let constant = |c: u32| circuit_builder.b_constant(c);
let next_base_row = |column_idx: HashBaseTableColumn| {
circuit_builder.input(NextBaseRow(column_idx.master_base_table_index()))
};
Expand Down Expand Up @@ -1271,7 +1271,7 @@ impl ExtHashTable {
) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
let challenge = |c| circuit_builder.challenge(c);
let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
let constant = |c: u64| circuit_builder.b_constant(c.into());
let constant = |c: u64| circuit_builder.b_constant(c);
let base_row = |column_idx: HashBaseTableColumn| {
circuit_builder.input(BaseRow(column_idx.master_base_table_index()))
};
Expand Down
19 changes: 9 additions & 10 deletions triton-vm/src/table/jump_stack_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl ExtJumpStackTable {
pub fn transition_constraints(
circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
let one = circuit_builder.b_constant(1u32.into());
let one = || circuit_builder.b_constant(1);
let call_opcode =
circuit_builder.b_constant(Instruction::Call(BFieldElement::default()).opcode_b());
let return_opcode = circuit_builder.b_constant(Instruction::Return.opcode_b());
Expand Down Expand Up @@ -104,13 +104,13 @@ impl ExtJumpStackTable {
// 1. The jump stack pointer jsp increases by 1
// or the jump stack pointer jsp does not change
let jsp_inc_or_stays =
(jsp_next.clone() - jsp.clone() - one.clone()) * (jsp_next.clone() - jsp.clone());
(jsp_next.clone() - jsp.clone() - one()) * (jsp_next.clone() - jsp.clone());

// 2. The jump stack pointer jsp increases by 1
// or current instruction ci is return
// or the jump stack origin jso does not change
let jsp_inc_by_one_or_ci_is_return =
(jsp_next.clone() - jsp.clone() - one.clone()) * (ci.clone() - return_opcode.clone());
(jsp_next.clone() - jsp.clone() - one()) * (ci.clone() - return_opcode.clone());
let jsp_inc_or_jso_stays_or_ci_is_ret =
jsp_inc_by_one_or_ci_is_return.clone() * (jso_next.clone() - jso);

Expand All @@ -124,11 +124,10 @@ impl ExtJumpStackTable {
// or the cycle count clk increases by 1
// or current instruction ci is call
// or current instruction ci is return
let jsp_inc_or_clk_inc_or_ci_call_or_ci_ret =
(jsp_next.clone() - jsp.clone() - one.clone())
* (clk_next.clone() - clk.clone() - one.clone())
* (ci.clone() - call_opcode)
* (ci - return_opcode);
let jsp_inc_or_clk_inc_or_ci_call_or_ci_ret = (jsp_next.clone() - jsp.clone() - one())
* (clk_next.clone() - clk.clone() - one())
* (ci.clone() - call_opcode)
* (ci - return_opcode);

// The running product for the permutation argument `rppa` accumulates one row in each
// row, relative to weights `a`, `b`, `c`, `d`, `e`, and indeterminate `α`.
Expand All @@ -152,8 +151,8 @@ impl ExtJumpStackTable {
let log_derivative_accumulates = (clock_jump_diff_log_derivative_next
- clock_jump_diff_log_derivative)
* (circuit_builder.challenge(ClockJumpDifferenceLookupIndeterminate) - clk_diff)
- one.clone();
let log_derivative_updates_correctly = (jsp_next.clone() - jsp.clone() - one)
- one();
let log_derivative_updates_correctly = (jsp_next.clone() - jsp.clone() - one())
* log_derivative_accumulates
+ (jsp_next - jsp) * log_derivative_remains;

Expand Down
Loading

0 comments on commit aecd75f

Please sign in to comment.