Skip to content

Commit

Permalink
Add constants for interactions. (#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware authored Nov 5, 2024
1 parent 5d388c1 commit 8957dea
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 40 deletions.
12 changes: 6 additions & 6 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ mod tests {
use num_traits::One;

use crate::constraint_framework::expr::{ColumnExpr, Expr, ExprEvaluator};
use crate::constraint_framework::{EvalAtRow, FrameworkEval};
use crate::constraint_framework::{EvalAtRow, FrameworkEval, ORIGINAL_TRACE_IDX};
use crate::core::fields::FieldExpOps;
#[test]
fn test_expr_eval() {
Expand All @@ -211,30 +211,30 @@ mod tests {
Box::new(Expr::Mul(
Box::new(Expr::Mul(
Box::new(Expr::Col(ColumnExpr {
interaction: 0,
interaction: ORIGINAL_TRACE_IDX,
idx: 0,
offset: 0
})),
Box::new(Expr::Col(ColumnExpr {
interaction: 0,
interaction: ORIGINAL_TRACE_IDX,
idx: 1,
offset: 0
}))
)),
Box::new(Expr::Col(ColumnExpr {
interaction: 0,
interaction: ORIGINAL_TRACE_IDX,
idx: 2,
offset: 0
}))
)),
Box::new(Expr::Inv(Box::new(Expr::Add(
Box::new(Expr::Col(ColumnExpr {
interaction: 0,
interaction: ORIGINAL_TRACE_IDX,
idx: 0,
offset: 0
})),
Box::new(Expr::Col(ColumnExpr {
interaction: 0,
interaction: ORIGINAL_TRACE_IDX,
idx: 1,
offset: 0
}))
Expand Down
10 changes: 7 additions & 3 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ mod tests {
use num_traits::One;

use super::{LogupAtRow, LookupElements};
use crate::constraint_framework::InfoEvaluator;
use crate::constraint_framework::{InfoEvaluator, INTERACTION_TRACE_IDX};
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand All @@ -308,8 +308,12 @@ mod tests {
#[test]
#[should_panic]
fn test_logup_not_finalized_panic() {
let mut logup =
LogupAtRow::<InfoEvaluator>::new(1, SecureField::one(), None, BaseField::one());
let mut logup = LogupAtRow::<InfoEvaluator>::new(
INTERACTION_TRACE_IDX,
SecureField::one(),
None,
BaseField::one(),
);
logup.write_frac(
&mut InfoEvaluator::default(),
Fraction::new(SecureField::one(), SecureField::one()),
Expand Down
6 changes: 5 additions & 1 deletion crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;

pub const ORIGINAL_TRACE_IDX: usize = 0;
pub const INTERACTION_TRACE_IDX: usize = 1;
pub const PREPROCESSED_TRACE_IDX: usize = 2;

/// A trait for evaluating expressions at some point or row.
pub trait EvalAtRow {
// TODO(Ohad): Use a better trait for these, like 'Algebra' or something.
Expand Down Expand Up @@ -67,7 +71,7 @@ pub trait EvalAtRow {

/// Returns the next mask value for the first interaction at offset 0.
fn next_trace_mask(&mut self) -> Self::F {
let [mask_item] = self.next_interaction_mask(0, [0]);
let [mask_item] = self.next_interaction_mask(ORIGINAL_TRACE_IDX, [0]);
mask_item
}

Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/examples/blake/round/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tracing::{span, Level};

use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::logup::LogupTraceGenerator;
use crate::constraint_framework::ORIGINAL_TRACE_IDX;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
Expand Down Expand Up @@ -37,7 +38,7 @@ pub struct TraceGenerator {
impl TraceGenerator {
fn new(log_size: u32) -> Self {
assert!(log_size >= LOG_N_LANES);
let trace = (0..blake_round_info().mask_offsets[0].len())
let trace = (0..blake_round_info().mask_offsets[ORIGINAL_TRACE_IDX].len())
.map(|_| unsafe { Col::<SimdBackend, BaseField>::uninitialized(1 << log_size) })
.collect_vec();
Self {
Expand Down
10 changes: 7 additions & 3 deletions crates/prover/src/examples/blake/round/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ use num_traits::Zero;

use super::{BlakeXorElements, N_ROUND_INPUT_FELTS};
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, PREPROCESSED_TRACE_IDX,
};
use crate::core::fields::qm31::SecureField;

pub type BlakeRoundComponent = FrameworkComponent<BlakeRoundEval>;

pub type RoundElements = LookupElements<N_ROUND_INPUT_FELTS>;

use crate::constraint_framework::INTERACTION_TRACE_IDX;

pub struct BlakeRoundEval {
pub log_size: u32,
pub xor_lookup_elements: BlakeXorElements,
Expand All @@ -28,12 +32,12 @@ impl FrameworkEval for BlakeRoundEval {
self.log_size + 1
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let [is_first] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
let blake_eval = constraints::BlakeRoundEval {
eval,
xor_lookup_elements: &self.xor_lookup_elements,
round_lookup_elements: &self.round_lookup_elements,
logup: LogupAtRow::new(1, self.total_sum, None, is_first),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, self.total_sum, None, is_first),
};
blake_eval.eval()
}
Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/examples/blake/scheduler/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use tracing::{span, Level};

use super::{blake_scheduler_info, BlakeElements};
use crate::constraint_framework::logup::LogupTraceGenerator;
use crate::constraint_framework::ORIGINAL_TRACE_IDX;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::qm31::PackedSecureField;
Expand Down Expand Up @@ -54,7 +55,7 @@ pub fn gen_trace(
let mut lookup_data = BlakeSchedulerLookupData::new(log_size);
let mut round_inputs = Vec::with_capacity(inputs.len() * N_ROUNDS);

let mut trace = (0..blake_scheduler_info().mask_offsets[0].len())
let mut trace = (0..blake_scheduler_info().mask_offsets[ORIGINAL_TRACE_IDX].len())
.map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) })
.collect_vec();

Expand Down
9 changes: 6 additions & 3 deletions crates/prover/src/examples/blake/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ use num_traits::Zero;
use super::round::RoundElements;
use super::N_ROUND_INPUT_FELTS;
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, INTERACTION_TRACE_IDX,
PREPROCESSED_TRACE_IDX,
};
use crate::core::fields::qm31::SecureField;

pub type BlakeSchedulerComponent = FrameworkComponent<BlakeSchedulerEval>;
Expand All @@ -29,12 +32,12 @@ impl FrameworkEval for BlakeSchedulerEval {
self.log_size + 1
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let [is_first] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
eval_blake_scheduler_constraints(
&mut eval,
&self.blake_lookup_elements,
&self.round_lookup_elements,
LogupAtRow::new(1, self.total_sum, None, is_first),
LogupAtRow::new(INTERACTION_TRACE_IDX, self.total_sum, None, is_first),
);
eval
}
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/examples/blake/xor_table/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use itertools::Itertools;

use super::{limb_bits, XorElements};
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::EvalAtRow;
use crate::constraint_framework::{EvalAtRow, PREPROCESSED_TRACE_IDX};
use crate::core::fields::m31::BaseField;
use crate::core::lookups::utils::Fraction;

Expand All @@ -19,9 +19,9 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32>
// al, bl are the constant columns for the inputs: All pairs of elements in [0,
// 2^LIMB_BITS).
// cl is the constant column for the xor: al ^ bl.
let [al] = self.eval.next_interaction_mask(2, [0]);
let [bl] = self.eval.next_interaction_mask(2, [0]);
let [cl] = self.eval.next_interaction_mask(2, [0]);
let [al] = self.eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
let [bl] = self.eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
let [cl] = self.eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);

let frac_chunks = (0..(1 << (2 * EXPAND_BITS)))
.map(|i| {
Expand Down
9 changes: 6 additions & 3 deletions crates/prover/src/examples/blake/xor_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ use num_traits::Zero;
pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace};

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, INTERACTION_TRACE_IDX,
PREPROCESSED_TRACE_IDX,
};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::Column;
use crate::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -103,11 +106,11 @@ impl<const ELEM_BITS: u32, const EXPAND_BITS: u32> FrameworkEval
column_bits::<ELEM_BITS, EXPAND_BITS>() + 1
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let [is_first] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> {
eval,
lookup_elements: &self.lookup_elements,
logup: LogupAtRow::new(1, self.claimed_sum, None, is_first),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, self.claimed_sum, None, is_first),
};
xor_eval.eval()
}
Expand Down
18 changes: 12 additions & 6 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::constraint_framework::logup::{
use crate::constraint_framework::preprocessed_columns::gen_is_first;
use crate::constraint_framework::{
assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator,
INTERACTION_TRACE_IDX, PREPROCESSED_TRACE_IDX,
};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
Expand Down Expand Up @@ -48,15 +49,20 @@ impl FrameworkEval for PlonkEval {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let mut logup = LogupAtRow::<_>::new(1, self.total_sum, Some(self.claimed_sum), is_first);
let [is_first] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
let mut logup = LogupAtRow::<_>::new(
INTERACTION_TRACE_IDX,
self.total_sum,
Some(self.claimed_sum),
is_first,
);

let [a_wire] = eval.next_interaction_mask(2, [0]);
let [b_wire] = eval.next_interaction_mask(2, [0]);
let [a_wire] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
let [b_wire] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
// Note: c_wire could also be implicit: (self.eval.point() - M31_CIRCLE_GEN.into_ef()).x.
// A constant column is easier though.
let [c_wire] = eval.next_interaction_mask(2, [0]);
let [op] = eval.next_interaction_mask(2, [0]);
let [c_wire] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
let [op] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);

let mult = eval.next_trace_mask();
let a_val = eval.next_trace_mask();
Expand Down
15 changes: 9 additions & 6 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use tracing::{info, span, Level};
use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements};
use crate::constraint_framework::preprocessed_columns::gen_is_first;
use crate::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator,
EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, INTERACTION_TRACE_IDX,
PREPROCESSED_TRACE_IDX,
};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
Expand Down Expand Up @@ -60,8 +61,8 @@ impl FrameworkEval for PoseidonEval {
self.log_n_rows + LOG_EXPAND
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let logup = LogupAtRow::new(1, self.total_sum, None, is_first);
let [is_first] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
let logup = LogupAtRow::new(INTERACTION_TRACE_IDX, self.total_sum, None, is_first);
eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements);
eval
}
Expand Down Expand Up @@ -401,7 +402,9 @@ mod tests {

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::preprocessed_columns::gen_is_first;
use crate::constraint_framework::{assert_constraints, EvalAtRow};
use crate::constraint_framework::{
assert_constraints, EvalAtRow, INTERACTION_TRACE_IDX, PREPROCESSED_TRACE_IDX,
};
use crate::core::air::Component;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -480,10 +483,10 @@ mod tests {
let trace_polys =
traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec());
assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| {
let [is_first] = eval.next_interaction_mask(2, [0]);
let [is_first] = eval.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
eval_poseidon_constraints(
&mut eval,
LogupAtRow::new(1, total_sum, None, is_first),
LogupAtRow::new(INTERACTION_TRACE_IDX, total_sum, None, is_first),
&lookup_elements,
);
});
Expand Down
12 changes: 9 additions & 3 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use num_traits::{One, Zero};

use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::constraint_framework::{
EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, INTERACTION_TRACE_IDX,
};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Channel;
Expand Down Expand Up @@ -40,8 +42,12 @@ impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE>
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let mut logup: LogupAtRow<E> =
LogupAtRow::new(1, self.total_sum, Some(self.claimed_sum), is_first);
let mut logup: LogupAtRow<E> = LogupAtRow::new(
INTERACTION_TRACE_IDX,
self.total_sum,
Some(self.claimed_sum),
is_first,
);

let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask());
let input_denom: E::EF = self.lookup_elements.combine(&input_state);
Expand Down

1 comment on commit 8957dea

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 8957dea Previous: f6214d1 Ratio
merkle throughput/simd merkle 30116404 ns/iter (± 474665) 14690867 ns/iter (± 434150) 2.05

This comment was automatically generated by workflow using github-action-benchmark.

CC: @shaharsamocha7

Please sign in to comment.