diff --git a/acvm-repo/acvm/src/compiler/transformers/mod.rs b/acvm-repo/acvm/src/compiler/transformers/mod.rs index c92b0516431..8f9ef63be1f 100644 --- a/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -1,5 +1,3 @@ -use std::collections::BTreeSet; - use acir::{ circuit::{ self, @@ -163,13 +161,12 @@ pub(super) fn transform_internal( ..acir }; let mut merge_optimizer = MergeExpressionsOptimizer::new(); + let (opcodes, new_acir_opcode_positions) = merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions); // n.b. if we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less. let mut acir = Circuit { - current_witness_index, - expression_width, opcodes, // The optimizer does not add new public inputs ..acir @@ -177,106 +174,121 @@ pub(super) fn transform_internal( // After the elimination of intermediate variables the `current_witness_index` is potentially higher than it needs to be, // which would cause gaps if we ran the optimization a second time, making it look like new variables were added. - // Here we figure out what is the final state of witnesses by visiting each opcode. - let witnesses = WitnessCollector::collect_from_circuit(&acir); - if let Some(max_witness) = witnesses.last() { - acir.current_witness_index = max_witness.0; - } + acir.current_witness_index = max_witness(&acir).witness_index(); (acir, new_acir_opcode_positions) } -/// Collect all witnesses in a circuit. -#[derive(Default, Clone, Debug)] -struct WitnessCollector { - witnesses: BTreeSet, +/// Find the witness with the highest ID in the circuit. +fn max_witness(circuit: &Circuit) -> Witness { + let mut witnesses = WitnessFolder::new(Witness::default(), |state, witness| { + *state = witness.max(*state); + }); + witnesses.fold_circuit(circuit); + witnesses.into_state() } -impl WitnessCollector { - /// Collect all witnesses in a circuit. - fn collect_from_circuit(circuit: &Circuit) -> BTreeSet { - let mut collector = Self::default(); - collector.extend_from_circuit(circuit); - collector.witnesses - } +/// Fold all witnesses in a circuit. +struct WitnessFolder { + state: S, + accumulate: A, +} - fn add(&mut self, witness: Witness) { - self.witnesses.insert(witness); +impl WitnessFolder +where + A: Fn(&mut S, Witness), +{ + /// Create the folder with some initial state and an accumulator function. + fn new(init: S, accumulate: A) -> Self { + Self { state: init, accumulate } } - fn add_many(&mut self, witnesses: &[Witness]) { - self.witnesses.extend(witnesses); + /// Take the accumulated state. + fn into_state(self) -> S { + self.state } /// Add all witnesses from the circuit. - fn extend_from_circuit(&mut self, circuit: &Circuit) { - self.witnesses.extend(&circuit.private_parameters); - self.witnesses.extend(&circuit.public_parameters.0); - self.witnesses.extend(&circuit.return_values.0); + fn fold_circuit(&mut self, circuit: &Circuit) { + self.fold_many(circuit.private_parameters.iter()); + self.fold_many(circuit.public_parameters.0.iter()); + self.fold_many(circuit.return_values.0.iter()); for opcode in &circuit.opcodes { - self.extend_from_opcode(opcode); + self.fold_opcode(opcode); + } + } + + /// Fold a witness into the state. + fn fold(&mut self, witness: Witness) { + (self.accumulate)(&mut self.state, witness); + } + + /// Fold many witnesses into the state. + fn fold_many<'w, I: Iterator>(&mut self, witnesses: I) { + for w in witnesses { + self.fold(*w); } } /// Add witnesses from the opcode. - fn extend_from_opcode(&mut self, opcode: &Opcode) { + fn fold_opcode(&mut self, opcode: &Opcode) { match opcode { Opcode::AssertZero(expr) => { - self.extend_from_expr(expr); + self.fold_expr(expr); } - Opcode::BlackBoxFuncCall(call) => self.extend_from_blackbox(call), + Opcode::BlackBoxFuncCall(call) => self.fold_blackbox(call), Opcode::MemoryOp { block_id: _, op, predicate } => { let MemOp { operation, index, value } = op; - self.extend_from_expr(operation); - self.extend_from_expr(index); - self.extend_from_expr(value); + self.fold_expr(operation); + self.fold_expr(index); + self.fold_expr(value); if let Some(pred) = predicate { - self.extend_from_expr(pred); + self.fold_expr(pred); } } Opcode::MemoryInit { block_id: _, init, block_type: _ } => { for w in init { - self.add(*w); + self.fold(*w); } } // We keep the display for a BrilligCall and circuit Call separate as they // are distinct in their functionality and we should maintain this separation for debugging. Opcode::BrilligCall { id: _, inputs, outputs, predicate } => { if let Some(pred) = predicate { - self.extend_from_expr(pred); + self.fold_expr(pred); } - self.extend_from_brillig_inputs(inputs); - self.extend_from_brillig_outputs(outputs); + self.fold_brillig_inputs(inputs); + self.fold_brillig_outputs(outputs); } Opcode::Call { id: _, inputs, outputs, predicate } => { if let Some(pred) = predicate { - self.extend_from_expr(pred); + self.fold_expr(pred); } - self.add_many(inputs); - self.add_many(outputs); + self.fold_many(inputs.iter()); + self.fold_many(outputs.iter()); } } } - fn extend_from_expr(&mut self, expr: &Expression) { + fn fold_expr(&mut self, expr: &Expression) { for i in &expr.mul_terms { - self.add(i.1); - self.add(i.2); + self.fold(i.1); + self.fold(i.2); } for i in &expr.linear_combinations { - self.add(i.1); + self.fold(i.1); } } - fn extend_from_brillig_inputs(&mut self, inputs: &[BrilligInputs]) { + fn fold_brillig_inputs(&mut self, inputs: &[BrilligInputs]) { for input in inputs { match input { BrilligInputs::Single(expr) => { - self.extend_from_expr(expr); + self.fold_expr(expr); } BrilligInputs::Array(exprs) => { for expr in exprs { - self.extend_from_expr(expr); + self.fold_expr(expr); } } BrilligInputs::MemoryArray(_) => {} @@ -284,45 +296,45 @@ impl WitnessCollector { } } - fn extend_from_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) { + fn fold_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) { for output in outputs { match output { BrilligOutputs::Simple(w) => { - self.add(*w); + self.fold(*w); } - BrilligOutputs::Array(ws) => self.add_many(ws), + BrilligOutputs::Array(ws) => self.fold_many(ws.iter()), } } } - fn extend_from_blackbox(&mut self, call: &BlackBoxFuncCall) { + fn fold_blackbox(&mut self, call: &BlackBoxFuncCall) { match call { BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => { - self.extend_from_function_inputs(inputs.as_slice()); - self.extend_from_function_inputs(iv.as_slice()); - self.extend_from_function_inputs(key.as_slice()); - self.add_many(outputs); + self.fold_function_inputs(inputs.as_slice()); + self.fold_function_inputs(iv.as_slice()); + self.fold_function_inputs(key.as_slice()); + self.fold_many(outputs.iter()); } BlackBoxFuncCall::AND { lhs, rhs, output } => { - self.extend_from_function_input(lhs); - self.extend_from_function_input(rhs); - self.add(*output); + self.fold_function_input(lhs); + self.fold_function_input(rhs); + self.fold(*output); } BlackBoxFuncCall::XOR { lhs, rhs, output } => { - self.extend_from_function_input(lhs); - self.extend_from_function_input(rhs); - self.add(*output); + self.fold_function_input(lhs); + self.fold_function_input(rhs); + self.fold(*output); } BlackBoxFuncCall::RANGE { input } => { - self.extend_from_function_input(input); + self.fold_function_input(input); } BlackBoxFuncCall::Blake2s { inputs, outputs } => { - self.extend_from_function_inputs(inputs.as_slice()); - self.add_many(outputs.as_slice()); + self.fold_function_inputs(inputs.as_slice()); + self.fold_many(outputs.iter()); } BlackBoxFuncCall::Blake3 { inputs, outputs } => { - self.extend_from_function_inputs(inputs.as_slice()); - self.add_many(outputs.as_slice()); + self.fold_function_inputs(inputs.as_slice()); + self.fold_many(outputs.iter()); } BlackBoxFuncCall::SchnorrVerify { public_key_x, @@ -331,11 +343,11 @@ impl WitnessCollector { message, output, } => { - self.extend_from_function_input(public_key_x); - self.extend_from_function_input(public_key_y); - self.extend_from_function_inputs(signature.as_slice()); - self.extend_from_function_inputs(message.as_slice()); - self.add(*output); + self.fold_function_input(public_key_x); + self.fold_function_input(public_key_y); + self.fold_function_inputs(signature.as_slice()); + self.fold_function_inputs(message.as_slice()); + self.fold(*output); } BlackBoxFuncCall::EcdsaSecp256k1 { public_key_x, @@ -344,11 +356,11 @@ impl WitnessCollector { hashed_message, output, } => { - self.extend_from_function_inputs(public_key_x.as_slice()); - self.extend_from_function_inputs(public_key_y.as_slice()); - self.extend_from_function_inputs(signature.as_slice()); - self.extend_from_function_inputs(hashed_message.as_slice()); - self.add(*output); + self.fold_function_inputs(public_key_x.as_slice()); + self.fold_function_inputs(public_key_y.as_slice()); + self.fold_function_inputs(signature.as_slice()); + self.fold_function_inputs(hashed_message.as_slice()); + self.fold(*output); } BlackBoxFuncCall::EcdsaSecp256r1 { public_key_x, @@ -357,31 +369,31 @@ impl WitnessCollector { hashed_message, output, } => { - self.extend_from_function_inputs(public_key_x.as_slice()); - self.extend_from_function_inputs(public_key_y.as_slice()); - self.extend_from_function_inputs(signature.as_slice()); - self.extend_from_function_inputs(hashed_message.as_slice()); - self.add(*output); + self.fold_function_inputs(public_key_x.as_slice()); + self.fold_function_inputs(public_key_y.as_slice()); + self.fold_function_inputs(signature.as_slice()); + self.fold_function_inputs(hashed_message.as_slice()); + self.fold(*output); } BlackBoxFuncCall::MultiScalarMul { points, scalars, outputs } => { - self.extend_from_function_inputs(points.as_slice()); - self.extend_from_function_inputs(scalars.as_slice()); + self.fold_function_inputs(points.as_slice()); + self.fold_function_inputs(scalars.as_slice()); let (x, y, i) = outputs; - self.add(*x); - self.add(*y); - self.add(*i); + self.fold(*x); + self.fold(*y); + self.fold(*i); } BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, outputs } => { - self.extend_from_function_inputs(input1.as_slice()); - self.extend_from_function_inputs(input2.as_slice()); + self.fold_function_inputs(input1.as_slice()); + self.fold_function_inputs(input2.as_slice()); let (x, y, i) = outputs; - self.add(*x); - self.add(*y); - self.add(*i); + self.fold(*x); + self.fold(*y); + self.fold(*i); } BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => { - self.extend_from_function_inputs(inputs.as_slice()); - self.add_many(outputs.as_slice()); + self.fold_function_inputs(inputs.as_slice()); + self.fold_many(outputs.iter()); } BlackBoxFuncCall::RecursiveAggregation { verification_key, @@ -390,42 +402,42 @@ impl WitnessCollector { key_hash, proof_type: _, } => { - self.extend_from_function_inputs(verification_key.as_slice()); - self.extend_from_function_inputs(proof.as_slice()); - self.extend_from_function_inputs(public_inputs.as_slice()); - self.extend_from_function_input(key_hash); + self.fold_function_inputs(verification_key.as_slice()); + self.fold_function_inputs(proof.as_slice()); + self.fold_function_inputs(public_inputs.as_slice()); + self.fold_function_input(key_hash); } BlackBoxFuncCall::BigIntAdd { .. } | BlackBoxFuncCall::BigIntSub { .. } | BlackBoxFuncCall::BigIntMul { .. } | BlackBoxFuncCall::BigIntDiv { .. } => {} BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus: _, output: _ } => { - self.extend_from_function_inputs(inputs.as_slice()); + self.fold_function_inputs(inputs.as_slice()); } BlackBoxFuncCall::BigIntToLeBytes { input: _, outputs } => { - self.add_many(outputs.as_slice()); + self.fold_many(outputs.iter()); } BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len: _ } => { - self.extend_from_function_inputs(inputs.as_slice()); - self.add_many(outputs.as_slice()); + self.fold_function_inputs(inputs.as_slice()); + self.fold_many(outputs.iter()); } BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => { - self.extend_from_function_inputs(inputs.as_slice()); - self.extend_from_function_inputs(hash_values.as_slice()); - self.add_many(outputs.as_slice()); + self.fold_function_inputs(inputs.as_slice()); + self.fold_function_inputs(hash_values.as_slice()); + self.fold_many(outputs.iter()); } } } - fn extend_from_function_input(&mut self, input: &FunctionInput) { + fn fold_function_input(&mut self, input: &FunctionInput) { if let circuit::opcodes::ConstantOrWitnessEnum::Witness(witness) = input.input() { - self.add(witness); + self.fold(witness); } } - fn extend_from_function_inputs(&mut self, inputs: &[FunctionInput]) { + fn fold_function_inputs(&mut self, inputs: &[FunctionInput]) { for input in inputs { - self.extend_from_function_input(input); + self.fold_function_input(input); } } }