Skip to content

Commit

Permalink
refactor!: merge result_evaluate back to first_round_evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 4, 2024
1 parent dc52220 commit 31a811e
Show file tree
Hide file tree
Showing 12 changed files with 96 additions and 89 deletions.
8 changes: 3 additions & 5 deletions crates/proof-of-sql/src/sql/proof/proof_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,14 @@ pub trait ProofPlan: Debug + Send + Sync + ProverEvaluate {

#[enum_dispatch::enum_dispatch(DynProofPlan)]
pub trait ProverEvaluate {
/// Evaluate the query and return the result.
fn result_evaluate<'a, S: Scalar>(
/// Evaluate the query, modify `FirstRoundBuilder` and return the result.
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>);

/// Evaluate the query and modify `FirstRoundBuilder` to form the query's proof.
fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder);

/// Evaluate the query and modify `FinalRoundBuilder` to store an intermediate representation
/// of the query result and track all the components needed to form the query's proof.
///
Expand Down
9 changes: 4 additions & 5 deletions crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,12 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
})
.collect();

// Evaluate query result
let (query_result, one_evaluation_lengths) = expr.result_evaluate(&alloc, &table_map);
// Prover First Round: Evaluate the query && get the right number of post result challenges
let mut first_round_builder = FirstRoundBuilder::new();
let (query_result, one_evaluation_lengths) =
expr.first_round_evaluate(&mut first_round_builder, &alloc, &table_map);
let provable_result = query_result.into();

// Prover First Round
let mut first_round_builder = FirstRoundBuilder::new();
expr.first_round_evaluate(&mut first_round_builder);
let range_length = one_evaluation_lengths
.iter()
.copied()
Expand Down
23 changes: 9 additions & 14 deletions crates/proof-of-sql/src/sql/proof/query_proof_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ impl Default for TrivialTestProofPlan {
}
}
impl ProverEvaluate for TrivialTestProofPlan {
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
Expand All @@ -54,8 +55,6 @@ impl ProverEvaluate for TrivialTestProofPlan {
)
}

fn first_round_evaluate(&self, _builder: &mut FirstRoundBuilder) {}

fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
Expand Down Expand Up @@ -225,16 +224,15 @@ impl Default for SquareTestProofPlan {
}
}
impl ProverEvaluate for SquareTestProofPlan {
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
(table([borrowed_bigint("a1", self.res, alloc)]), vec![2])
}

fn first_round_evaluate(&self, _builder: &mut FirstRoundBuilder) {}

fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
Expand Down Expand Up @@ -409,16 +407,15 @@ impl Default for DoubleSquareTestProofPlan {
}
}
impl ProverEvaluate for DoubleSquareTestProofPlan {
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
(table([borrowed_bigint("a1", self.res, alloc)]), vec![2])
}

fn first_round_evaluate(&self, _builder: &mut FirstRoundBuilder) {}

fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
Expand Down Expand Up @@ -623,16 +620,14 @@ fn verify_fails_the_result_doesnt_satisfy_an_intermediate_equation() {
#[derive(Debug, Serialize)]
struct ChallengeTestProofPlan {}
impl ProverEvaluate for ChallengeTestProofPlan {
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
(table([borrowed_bigint("a1", [9, 25], alloc)]), vec![2])
}

fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) {
builder.request_post_result_challenges(2);
(table([borrowed_bigint("a1", [9, 25], alloc)]), vec![2])
}

fn final_round_evaluate<'a, S: Scalar>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ pub(super) struct EmptyTestQueryExpr {
pub(super) columns: usize,
}
impl ProverEvaluate for EmptyTestQueryExpr {
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
Expand All @@ -40,7 +41,7 @@ impl ProverEvaluate for EmptyTestQueryExpr {
vec![self.length],
)
}
fn first_round_evaluate(&self, _builder: &mut FirstRoundBuilder) {}

fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
Expand Down
5 changes: 2 additions & 3 deletions crates/proof-of-sql/src/sql/proof_plans/empty_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ impl ProofPlan for EmptyExec {

impl ProverEvaluate for EmptyExec {
#[tracing::instrument(name = "EmptyExec::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
_alloc: &'a Bump,
_table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
Expand All @@ -82,8 +83,6 @@ impl ProverEvaluate for EmptyExec {
)
}

fn first_round_evaluate(&self, _builder: &mut FirstRoundBuilder) {}

#[tracing::instrument(name = "EmptyExec::final_round_evaluate", level = "debug", skip_all)]
#[allow(unused_variables)]
fn final_round_evaluate<'a, S: Scalar>(
Expand Down
8 changes: 3 additions & 5 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ pub type FilterExec = OstensibleFilterExec<HonestProver>;

impl ProverEvaluate for FilterExec {
#[tracing::instrument(name = "FilterExec::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
Expand Down Expand Up @@ -186,11 +187,8 @@ impl ProverEvaluate for FilterExec {
TableOptions::new(Some(output_length)),
)
.expect("Failed to create table from iterator");
(res, vec![output_length])
}

fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) {
builder.request_post_result_challenges(2);
(res, vec![output_length])
}

#[tracing::instrument(name = "FilterExec::final_round_evaluate", level = "debug", skip_all)]
Expand Down
56 changes: 34 additions & 22 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use crate::{
},
sql::{
proof::{
exercise_verification, ProofPlan, ProvableQueryResult, ProverEvaluate,
VerifiableQueryResult,
exercise_verification, FirstRoundBuilder, ProofPlan, ProvableQueryResult,
ProverEvaluate, VerifiableQueryResult,
},
proof_exprs::{test_utility::*, ColumnExpr, DynProofExpr, LiteralExpr, TableExpr},
},
Expand Down Expand Up @@ -177,7 +177,7 @@ fn we_can_prove_and_get_the_correct_result_from_a_basic_filter() {
}

#[test]
fn we_can_get_an_empty_result_from_a_basic_filter_on_an_empty_table_using_result_evaluate() {
fn we_can_get_an_empty_result_from_a_basic_filter_on_an_empty_table_using_first_round_evaluate() {
let alloc = Bump::new();
let data = table([
borrowed_bigint("a", [0; 0], &alloc),
Expand Down Expand Up @@ -207,10 +207,13 @@ fn we_can_get_an_empty_result_from_a_basic_filter_on_an_empty_table_using_result
ColumnType::Decimal75(Precision::new(75).unwrap(), 0),
),
];
let res: OwnedTable<Curve25519Scalar> =
ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map).0)
.to_owned_table(fields)
.unwrap();
let first_round_builder = &mut FirstRoundBuilder::new();
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(
expr.first_round_evaluate(first_round_builder, &alloc, &table_map)
.0,
)
.to_owned_table(fields)
.unwrap();
let expected: OwnedTable<Curve25519Scalar> = owned_table([
bigint("b", [0; 0]),
int128("c", [0; 0]),
Expand All @@ -222,7 +225,7 @@ fn we_can_get_an_empty_result_from_a_basic_filter_on_an_empty_table_using_result
}

#[test]
fn we_can_get_an_empty_result_from_a_basic_filter_using_result_evaluate() {
fn we_can_get_an_empty_result_from_a_basic_filter_using_first_round_evaluate() {
let alloc = Bump::new();
let data = table([
borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc),
Expand Down Expand Up @@ -252,10 +255,13 @@ fn we_can_get_an_empty_result_from_a_basic_filter_using_result_evaluate() {
ColumnType::Decimal75(Precision::new(1).unwrap(), 0),
),
];
let res: OwnedTable<Curve25519Scalar> =
ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map).0)
.to_owned_table(fields)
.unwrap();
let first_round_builder = &mut FirstRoundBuilder::new();
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(
expr.first_round_evaluate(first_round_builder, &alloc, &table_map)
.0,
)
.to_owned_table(fields)
.unwrap();
let expected: OwnedTable<Curve25519Scalar> = owned_table([
bigint("b", [0; 0]),
int128("c", [0; 0]),
Expand All @@ -267,7 +273,7 @@ fn we_can_get_an_empty_result_from_a_basic_filter_using_result_evaluate() {
}

#[test]
fn we_can_get_no_columns_from_a_basic_filter_with_no_selected_columns_using_result_evaluate() {
fn we_can_get_no_columns_from_a_basic_filter_with_no_selected_columns_using_first_round_evaluate() {
let alloc = Bump::new();
let data = table([
borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc),
Expand All @@ -285,16 +291,19 @@ fn we_can_get_no_columns_from_a_basic_filter_with_no_selected_columns_using_resu
let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(5));
let expr = filter(cols_expr_plan(t, &[], &accessor), tab(t), where_clause);
let fields = &[];
let res: OwnedTable<Curve25519Scalar> =
ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map).0)
.to_owned_table(fields)
.unwrap();
let first_round_builder = &mut FirstRoundBuilder::new();
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(
expr.first_round_evaluate(first_round_builder, &alloc, &table_map)
.0,
)
.to_owned_table(fields)
.unwrap();
let expected = OwnedTable::try_new(IndexMap::default()).unwrap();
assert_eq!(res, expected);
}

#[test]
fn we_can_get_the_correct_result_from_a_basic_filter_using_result_evaluate() {
fn we_can_get_the_correct_result_from_a_basic_filter_using_first_round_evaluate() {
let alloc = Bump::new();
let data = table([
borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc),
Expand Down Expand Up @@ -324,10 +333,13 @@ fn we_can_get_the_correct_result_from_a_basic_filter_using_result_evaluate() {
ColumnType::Decimal75(Precision::new(1).unwrap(), 0),
),
];
let res: OwnedTable<Curve25519Scalar> =
ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map).0)
.to_owned_table(fields)
.unwrap();
let first_round_builder = &mut FirstRoundBuilder::new();
let res: OwnedTable<Curve25519Scalar> = ProvableQueryResult::from(
expr.first_round_evaluate(first_round_builder, &alloc, &table_map)
.0,
)
.to_owned_table(fields)
.unwrap();
let expected: OwnedTable<Curve25519Scalar> = owned_table([
bigint("b", [3, 5]),
int128("c", [3, 5]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ impl ProverEvaluate for DishonestFilterExec {
level = "debug",
skip_all
)]
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
Expand Down Expand Up @@ -65,11 +66,8 @@ impl ProverEvaluate for DishonestFilterExec {
TableOptions::new(Some(output_length)),
)
.expect("Failed to create table from iterator");
(res, vec![output_length])
}

fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) {
builder.request_post_result_challenges(2);
(res, vec![output_length])
}

#[tracing::instrument(
Expand Down
8 changes: 3 additions & 5 deletions crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,9 @@ impl ProofPlan for GroupByExec {

impl ProverEvaluate for GroupByExec {
#[tracing::instrument(name = "GroupByExec::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
Expand Down Expand Up @@ -255,11 +256,8 @@ impl ProverEvaluate for GroupByExec {
),
)
.expect("Failed to create table from column references");
(res, vec![count_column.len()])
}

fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) {
builder.request_post_result_challenges(2);
(res, vec![count_column.len()])
}

#[tracing::instrument(name = "GroupByExec::final_round_evaluate", level = "debug", skip_all)]
Expand Down
5 changes: 2 additions & 3 deletions crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ impl ProofPlan for ProjectionExec {

impl ProverEvaluate for ProjectionExec {
#[tracing::instrument(name = "ProjectionExec::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a, S: Scalar>(
fn first_round_evaluate<'a, S: Scalar>(
&self,
_builder: &mut FirstRoundBuilder,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
Expand All @@ -120,8 +121,6 @@ impl ProverEvaluate for ProjectionExec {
)
}

fn first_round_evaluate(&self, _builder: &mut FirstRoundBuilder) {}

#[tracing::instrument(
name = "ProjectionExec::final_round_evaluate",
level = "debug",
Expand Down
Loading

0 comments on commit 31a811e

Please sign in to comment.