From 3358f31f6fd2334265d96bd3dd58bdebc7a14e49 Mon Sep 17 00:00:00 2001 From: Tom French Date: Fri, 15 Nov 2024 17:36:59 +0000 Subject: [PATCH] fix: take blackbox function outputs into account when merging expressions --- .../compiler/optimizers/merge_expressions.rs | 72 +++++++++++++++---- 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs index c3c80bec2ae..e2585f64acd 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -153,15 +153,18 @@ impl MergeExpressionsOptimizer { // Returns the input witnesses used by the opcode fn witness_inputs(&self, opcode: &Opcode) -> BTreeSet { - let mut witnesses = BTreeSet::new(); match opcode { Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr), - Opcode::BlackBoxFuncCall(bb_func) => bb_func.get_input_witnesses(), + Opcode::BlackBoxFuncCall(bb_func) => { + let mut witnesses = bb_func.get_input_witnesses(); + witnesses.extend(bb_func.get_outputs_vec()); + + witnesses + } Opcode::Directive(Directive::ToLeRadix { a, .. }) => CircuitSimulator::expr_wit(a), Opcode::MemoryOp { block_id: _, op, predicate } => { //index et value, et predicate - let mut witnesses = BTreeSet::new(); - witnesses.extend(CircuitSimulator::expr_wit(&op.index)); + let mut witnesses = CircuitSimulator::expr_wit(&op.index); witnesses.extend(CircuitSimulator::expr_wit(&op.value)); if let Some(p) = predicate { witnesses.extend(CircuitSimulator::expr_wit(p)); @@ -173,6 +176,7 @@ impl MergeExpressionsOptimizer { init.iter().cloned().collect() } Opcode::BrilligCall { inputs, outputs, .. } => { + let mut witnesses = BTreeSet::new(); for i in inputs { witnesses.extend(self.brillig_input_wit(i)); } @@ -182,12 +186,9 @@ impl MergeExpressionsOptimizer { witnesses } Opcode::Call { id: _, inputs, outputs, predicate } => { - for i in inputs { - witnesses.insert(*i); - } - for i in outputs { - witnesses.insert(*i); - } + let mut witnesses: BTreeSet = BTreeSet::from_iter(inputs.iter().copied()); + witnesses.extend(outputs); + if let Some(p) = predicate { witnesses.extend(CircuitSimulator::expr_wit(p)); } @@ -235,7 +236,7 @@ mod tests { acir_field::AcirField, circuit::{ brillig::{BrilligFunctionId, BrilligOutputs}, - opcodes::FunctionInput, + opcodes::{BlackBoxFuncCall, FunctionInput}, Circuit, ExpressionWidth, Opcode, PublicInputs, }, native_types::{Expression, Witness}, @@ -243,7 +244,7 @@ mod tests { }; use std::collections::BTreeSet; - fn check_circuit(circuit: Circuit) { + fn check_circuit(circuit: Circuit) -> Circuit { assert!(CircuitSimulator::default().check_circuit(&circuit)); let mut merge_optimizer = MergeExpressionsOptimizer::new(); let acir_opcode_positions = vec![0; 20]; @@ -253,6 +254,7 @@ mod tests { optimized_circuit.opcodes = opcodes; // check that the circuit is still valid after optimization assert!(CircuitSimulator::default().check_circuit(&optimized_circuit)); + optimized_circuit } #[test] @@ -352,4 +354,50 @@ mod tests { }; check_circuit(circuit); } + + #[test] + fn takes_blackbox_opcode_outputs_into_account() { + // Regression test for https://github.com/noir-lang/noir/issues/6527 + // Previously we would not track the usage of witness 4 in the output of the blackbox function. + // We would then merge the final two opcodes losing the check that the brillig call must match + // with `_0 ^ _1`. + + let circuit: Circuit = Circuit { + current_witness_index: 7, + opcodes: vec![ + Opcode::BrilligCall { + id: BrilligFunctionId(0), + inputs: Vec::new(), + outputs: vec![BrilligOutputs::Simple(Witness(3))], + predicate: None, + }, + Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND { + lhs: FunctionInput::witness(Witness(0), 8), + rhs: FunctionInput::witness(Witness(1), 8), + output: Witness(4), + }), + Opcode::AssertZero(Expression { + linear_combinations: vec![ + (FieldElement::one(), Witness(3)), + (-FieldElement::one(), Witness(4)), + ], + ..Default::default() + }), + Opcode::AssertZero(Expression { + linear_combinations: vec![ + (-FieldElement::one(), Witness(2)), + (FieldElement::one(), Witness(4)), + ], + ..Default::default() + }), + ], + expression_width: ExpressionWidth::Bounded { width: 4 }, + private_parameters: BTreeSet::from([Witness(0), Witness(1)]), + return_values: PublicInputs(BTreeSet::from([Witness(2)])), + ..Default::default() + }; + + let new_circuit = check_circuit(circuit.clone()); + assert_eq!(circuit, new_circuit); + } }