Skip to content

Commit

Permalink
Fold over the witnesses
Browse files Browse the repository at this point in the history
  • Loading branch information
aakoshh committed Dec 4, 2024
1 parent 02c53bd commit 09fb747
Showing 1 changed file with 123 additions and 111 deletions.
234 changes: 123 additions & 111 deletions acvm-repo/acvm/src/compiler/transformers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::BTreeSet;

use acir::{
circuit::{
self,
Expand Down Expand Up @@ -163,166 +161,180 @@ pub(super) fn transform_internal<F: AcirField>(
..acir
};
let mut merge_optimizer = MergeExpressionsOptimizer::new();

let (opcodes, new_acir_opcode_positions) =
merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions);

// n.b. if we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less.
let mut acir = Circuit {
current_witness_index,
expression_width,
opcodes,
// The optimizer does not add new public inputs
..acir
};

// After the elimination of intermediate variables the `current_witness_index` is potentially higher than it needs to be,
// which would cause gaps if we ran the optimization a second time, making it look like new variables were added.
// Here we figure out what is the final state of witnesses by visiting each opcode.
let witnesses = WitnessCollector::collect_from_circuit(&acir);
if let Some(max_witness) = witnesses.last() {
acir.current_witness_index = max_witness.0;
}
acir.current_witness_index = max_witness(&acir).witness_index();

(acir, new_acir_opcode_positions)
}

/// Collect all witnesses in a circuit.
#[derive(Default, Clone, Debug)]
struct WitnessCollector {
witnesses: BTreeSet<Witness>,
/// Find the witness with the highest ID in the circuit.
fn max_witness<F: AcirField>(circuit: &Circuit<F>) -> Witness {
let mut witnesses = WitnessFolder::new(Witness::default(), |state, witness| {
*state = witness.max(*state);
});
witnesses.fold_circuit(circuit);
witnesses.into_state()
}

impl WitnessCollector {
/// Collect all witnesses in a circuit.
fn collect_from_circuit<F: AcirField>(circuit: &Circuit<F>) -> BTreeSet<Witness> {
let mut collector = Self::default();
collector.extend_from_circuit(circuit);
collector.witnesses
}
/// Fold all witnesses in a circuit.
struct WitnessFolder<S, A> {
state: S,
accumulate: A,
}

fn add(&mut self, witness: Witness) {
self.witnesses.insert(witness);
impl<S, A> WitnessFolder<S, A>
where
A: Fn(&mut S, Witness),
{
/// Create the folder with some initial state and an accumulator function.
fn new(init: S, accumulate: A) -> Self {
Self { state: init, accumulate }
}

fn add_many(&mut self, witnesses: &[Witness]) {
self.witnesses.extend(witnesses);
/// Take the accumulated state.
fn into_state(self) -> S {
self.state
}

/// Add all witnesses from the circuit.
fn extend_from_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) {
self.witnesses.extend(&circuit.private_parameters);
self.witnesses.extend(&circuit.public_parameters.0);
self.witnesses.extend(&circuit.return_values.0);
fn fold_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) {
self.fold_many(circuit.private_parameters.iter());
self.fold_many(circuit.public_parameters.0.iter());
self.fold_many(circuit.return_values.0.iter());
for opcode in &circuit.opcodes {
self.extend_from_opcode(opcode);
self.fold_opcode(opcode);
}
}

/// Fold a witness into the state.
fn fold(&mut self, witness: Witness) {
(self.accumulate)(&mut self.state, witness);
}

/// Fold many witnesses into the state.
fn fold_many<'w, I: Iterator<Item = &'w Witness>>(&mut self, witnesses: I) {
for w in witnesses {
self.fold(*w);
}
}

/// Add witnesses from the opcode.
fn extend_from_opcode<F: AcirField>(&mut self, opcode: &Opcode<F>) {
fn fold_opcode<F: AcirField>(&mut self, opcode: &Opcode<F>) {
match opcode {
Opcode::AssertZero(expr) => {
self.extend_from_expr(expr);
self.fold_expr(expr);
}
Opcode::BlackBoxFuncCall(call) => self.extend_from_blackbox(call),
Opcode::BlackBoxFuncCall(call) => self.fold_blackbox(call),
Opcode::MemoryOp { block_id: _, op, predicate } => {
let MemOp { operation, index, value } = op;
self.extend_from_expr(operation);
self.extend_from_expr(index);
self.extend_from_expr(value);
self.fold_expr(operation);
self.fold_expr(index);
self.fold_expr(value);
if let Some(pred) = predicate {
self.extend_from_expr(pred);
self.fold_expr(pred);
}
}
Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
for w in init {
self.add(*w);
self.fold(*w);
}
}
// We keep the display for a BrilligCall and circuit Call separate as they
// are distinct in their functionality and we should maintain this separation for debugging.
Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
if let Some(pred) = predicate {
self.extend_from_expr(pred);
self.fold_expr(pred);
}
self.extend_from_brillig_inputs(inputs);
self.extend_from_brillig_outputs(outputs);
self.fold_brillig_inputs(inputs);
self.fold_brillig_outputs(outputs);
}
Opcode::Call { id: _, inputs, outputs, predicate } => {
if let Some(pred) = predicate {
self.extend_from_expr(pred);
self.fold_expr(pred);
}
self.add_many(inputs);
self.add_many(outputs);
self.fold_many(inputs.iter());
self.fold_many(outputs.iter());
}
}
}

fn extend_from_expr<F: AcirField>(&mut self, expr: &Expression<F>) {
fn fold_expr<F: AcirField>(&mut self, expr: &Expression<F>) {
for i in &expr.mul_terms {
self.add(i.1);
self.add(i.2);
self.fold(i.1);
self.fold(i.2);
}
for i in &expr.linear_combinations {
self.add(i.1);
self.fold(i.1);
}
}

fn extend_from_brillig_inputs<F: AcirField>(&mut self, inputs: &[BrilligInputs<F>]) {
fn fold_brillig_inputs<F: AcirField>(&mut self, inputs: &[BrilligInputs<F>]) {
for input in inputs {
match input {
BrilligInputs::Single(expr) => {
self.extend_from_expr(expr);
self.fold_expr(expr);
}
BrilligInputs::Array(exprs) => {
for expr in exprs {
self.extend_from_expr(expr);
self.fold_expr(expr);
}
}
BrilligInputs::MemoryArray(_) => {}
}
}
}

fn extend_from_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) {
fn fold_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) {
for output in outputs {
match output {
BrilligOutputs::Simple(w) => {
self.add(*w);
self.fold(*w);
}
BrilligOutputs::Array(ws) => self.add_many(ws),
BrilligOutputs::Array(ws) => self.fold_many(ws.iter()),
}
}
}

fn extend_from_blackbox<F: AcirField>(&mut self, call: &BlackBoxFuncCall<F>) {
fn fold_blackbox<F: AcirField>(&mut self, call: &BlackBoxFuncCall<F>) {
match call {
BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.extend_from_function_inputs(iv.as_slice());
self.extend_from_function_inputs(key.as_slice());
self.add_many(outputs);
self.fold_function_inputs(inputs.as_slice());
self.fold_function_inputs(iv.as_slice());
self.fold_function_inputs(key.as_slice());
self.fold_many(outputs.iter());
}
BlackBoxFuncCall::AND { lhs, rhs, output } => {
self.extend_from_function_input(lhs);
self.extend_from_function_input(rhs);
self.add(*output);
self.fold_function_input(lhs);
self.fold_function_input(rhs);
self.fold(*output);
}
BlackBoxFuncCall::XOR { lhs, rhs, output } => {
self.extend_from_function_input(lhs);
self.extend_from_function_input(rhs);
self.add(*output);
self.fold_function_input(lhs);
self.fold_function_input(rhs);
self.fold(*output);
}
BlackBoxFuncCall::RANGE { input } => {
self.extend_from_function_input(input);
self.fold_function_input(input);
}
BlackBoxFuncCall::Blake2s { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
self.fold_function_inputs(inputs.as_slice());
self.fold_many(outputs.iter());
}
BlackBoxFuncCall::Blake3 { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
self.fold_function_inputs(inputs.as_slice());
self.fold_many(outputs.iter());
}
BlackBoxFuncCall::SchnorrVerify {
public_key_x,
Expand All @@ -331,11 +343,11 @@ impl WitnessCollector {
message,
output,
} => {
self.extend_from_function_input(public_key_x);
self.extend_from_function_input(public_key_y);
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(message.as_slice());
self.add(*output);
self.fold_function_input(public_key_x);
self.fold_function_input(public_key_y);
self.fold_function_inputs(signature.as_slice());
self.fold_function_inputs(message.as_slice());
self.fold(*output);
}
BlackBoxFuncCall::EcdsaSecp256k1 {
public_key_x,
Expand All @@ -344,11 +356,11 @@ impl WitnessCollector {
hashed_message,
output,
} => {
self.extend_from_function_inputs(public_key_x.as_slice());
self.extend_from_function_inputs(public_key_y.as_slice());
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(hashed_message.as_slice());
self.add(*output);
self.fold_function_inputs(public_key_x.as_slice());
self.fold_function_inputs(public_key_y.as_slice());
self.fold_function_inputs(signature.as_slice());
self.fold_function_inputs(hashed_message.as_slice());
self.fold(*output);
}
BlackBoxFuncCall::EcdsaSecp256r1 {
public_key_x,
Expand All @@ -357,31 +369,31 @@ impl WitnessCollector {
hashed_message,
output,
} => {
self.extend_from_function_inputs(public_key_x.as_slice());
self.extend_from_function_inputs(public_key_y.as_slice());
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(hashed_message.as_slice());
self.add(*output);
self.fold_function_inputs(public_key_x.as_slice());
self.fold_function_inputs(public_key_y.as_slice());
self.fold_function_inputs(signature.as_slice());
self.fold_function_inputs(hashed_message.as_slice());
self.fold(*output);
}
BlackBoxFuncCall::MultiScalarMul { points, scalars, outputs } => {
self.extend_from_function_inputs(points.as_slice());
self.extend_from_function_inputs(scalars.as_slice());
self.fold_function_inputs(points.as_slice());
self.fold_function_inputs(scalars.as_slice());
let (x, y, i) = outputs;
self.add(*x);
self.add(*y);
self.add(*i);
self.fold(*x);
self.fold(*y);
self.fold(*i);
}
BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, outputs } => {
self.extend_from_function_inputs(input1.as_slice());
self.extend_from_function_inputs(input2.as_slice());
self.fold_function_inputs(input1.as_slice());
self.fold_function_inputs(input2.as_slice());
let (x, y, i) = outputs;
self.add(*x);
self.add(*y);
self.add(*i);
self.fold(*x);
self.fold(*y);
self.fold(*i);
}
BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
self.fold_function_inputs(inputs.as_slice());
self.fold_many(outputs.iter());
}
BlackBoxFuncCall::RecursiveAggregation {
verification_key,
Expand All @@ -390,42 +402,42 @@ impl WitnessCollector {
key_hash,
proof_type: _,
} => {
self.extend_from_function_inputs(verification_key.as_slice());
self.extend_from_function_inputs(proof.as_slice());
self.extend_from_function_inputs(public_inputs.as_slice());
self.extend_from_function_input(key_hash);
self.fold_function_inputs(verification_key.as_slice());
self.fold_function_inputs(proof.as_slice());
self.fold_function_inputs(public_inputs.as_slice());
self.fold_function_input(key_hash);
}
BlackBoxFuncCall::BigIntAdd { .. }
| BlackBoxFuncCall::BigIntSub { .. }
| BlackBoxFuncCall::BigIntMul { .. }
| BlackBoxFuncCall::BigIntDiv { .. } => {}
BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus: _, output: _ } => {
self.extend_from_function_inputs(inputs.as_slice());
self.fold_function_inputs(inputs.as_slice());
}
BlackBoxFuncCall::BigIntToLeBytes { input: _, outputs } => {
self.add_many(outputs.as_slice());
self.fold_many(outputs.iter());
}
BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len: _ } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
self.fold_function_inputs(inputs.as_slice());
self.fold_many(outputs.iter());
}
BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.extend_from_function_inputs(hash_values.as_slice());
self.add_many(outputs.as_slice());
self.fold_function_inputs(inputs.as_slice());
self.fold_function_inputs(hash_values.as_slice());
self.fold_many(outputs.iter());
}
}
}

fn extend_from_function_input<F: AcirField>(&mut self, input: &FunctionInput<F>) {
fn fold_function_input<F: AcirField>(&mut self, input: &FunctionInput<F>) {
if let circuit::opcodes::ConstantOrWitnessEnum::Witness(witness) = input.input() {
self.add(witness);
self.fold(witness);
}
}

fn extend_from_function_inputs<F: AcirField>(&mut self, inputs: &[FunctionInput<F>]) {
fn fold_function_inputs<F: AcirField>(&mut self, inputs: &[FunctionInput<F>]) {
for input in inputs {
self.extend_from_function_input(input);
self.fold_function_input(input);
}
}
}

0 comments on commit 09fb747

Please sign in to comment.