diff --git a/acir/src/circuit/opcodes.rs b/acir/src/circuit/opcodes.rs index ea6b5fac3..d422ffb17 100644 --- a/acir/src/circuit/opcodes.rs +++ b/acir/src/circuit/opcodes.rs @@ -24,6 +24,8 @@ pub enum Opcode { MemoryOp { block_id: BlockId, op: MemOp, + /// Predicate of the memory operation - indicates if it should be skipped + predicate: Option, }, MemoryInit { block_id: BlockId, @@ -158,8 +160,12 @@ impl std::fmt::Display for Opcode { writeln!(f, "outputs: {:?}", brillig.outputs)?; writeln!(f, "{:?}", brillig.bytecode) } - Opcode::MemoryOp { block_id, op } => { + Opcode::MemoryOp { block_id, op, predicate } => { write!(f, "MEM ")?; + if let Some(pred) = predicate { + writeln!(f, "PREDICATE = {pred}")?; + } + let is_read = op.operation.is_zero(); let is_write = op.operation == Expression::one(); if is_read { diff --git a/acir/tests/test_program_serialization.rs b/acir/tests/test_program_serialization.rs index b2d3a0e16..6a0c3c290 100644 --- a/acir/tests/test_program_serialization.rs +++ b/acir/tests/test_program_serialization.rs @@ -15,7 +15,7 @@ use acir::{ circuit::{ brillig::{Brillig, BrilligInputs, BrilligOutputs}, directives::Directive, - opcodes::{BlackBoxFuncCall, FunctionInput}, + opcodes::{BlackBoxFuncCall, BlockId, FunctionInput, MemOp}, Circuit, Opcode, PublicInputs, }, native_types::{Expression, Witness}, @@ -340,3 +340,40 @@ fn complex_brillig_foreign_call() { assert_eq!(bytes, expected_serialization) } + +#[test] +fn memory_op_circuit() { + let init = vec![Witness(1), Witness(2)]; + + let memory_init = Opcode::MemoryInit { block_id: BlockId(0), init }; + let write = Opcode::MemoryOp { + block_id: BlockId(0), + op: MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()), + predicate: None, + }; + let read = Opcode::MemoryOp { + block_id: BlockId(0), + op: MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)), + predicate: None, + }; + + let circuit = Circuit { + current_witness_index: 5, + opcodes: vec![memory_init, write, read], + private_parameters: BTreeSet::from([Witness(1), Witness(2), Witness(3)]), + return_values: PublicInputs([Witness(4)].into()), + ..Circuit::default() + }; + let mut bytes = Vec::new(); + circuit.write(&mut bytes).unwrap(); + + let expected_serialization: Vec = vec![ + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 146, 49, 14, 0, 32, 8, 3, 139, 192, 127, 240, 7, + 254, 255, 85, 198, 136, 9, 131, 155, 48, 216, 165, 76, 77, 57, 80, 0, 140, 45, 117, 111, + 238, 228, 179, 224, 174, 225, 110, 111, 234, 213, 185, 148, 156, 203, 121, 89, 86, 13, 215, + 126, 131, 43, 153, 187, 115, 40, 185, 62, 153, 3, 136, 83, 60, 30, 96, 2, 12, 235, 225, + 124, 14, 3, 0, 0, + ]; + + assert_eq!(bytes, expected_serialization) +} diff --git a/acvm/src/pwg/memory_op.rs b/acvm/src/pwg/memory_op.rs index fadc1bd60..7def24c0e 100644 --- a/acvm/src/pwg/memory_op.rs +++ b/acvm/src/pwg/memory_op.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use acir::{ circuit::opcodes::MemOp, - native_types::{Witness, WitnessMap}, + native_types::{Expression, Witness, WitnessMap}, FieldElement, }; @@ -63,6 +63,7 @@ impl MemoryOpSolver { &mut self, op: &MemOp, initial_witness: &mut WitnessMap, + predicate: &Option, ) -> Result<(), OpcodeResolutionError> { let operation = get_value(&op.operation, initial_witness)?; @@ -79,6 +80,12 @@ impl MemoryOpSolver { // `operation == 0` implies a read operation. (`operation == 1` implies write operation). let is_read_operation = operation.is_zero(); + // If the predicate is `None`, then we simply return the value 1 + let pred_value = match predicate { + Some(pred) => get_value(pred, initial_witness), + None => Ok(FieldElement::one()), + }?; + if is_read_operation { // `value_read = arr[memory_index]` // @@ -88,7 +95,13 @@ impl MemoryOpSolver { "Memory must be read into a specified witness index, encountered an Expression", ); - let value_in_array = self.read_memory_index(memory_index)?; + // A zero predicate indicates that we should skip the read operation + // and zero out the operation's output. + let value_in_array = if pred_value.is_zero() { + FieldElement::zero() + } else { + self.read_memory_index(memory_index)? + }; insert_value(&value_read_witness, value_in_array, initial_witness) } else { // `arr[memory_index] = value_write` @@ -97,9 +110,15 @@ impl MemoryOpSolver { // into the memory block. let value_write = value; - let value_to_write = get_value(&value_write, initial_witness)?; - - self.write_memory_index(memory_index, value_to_write) + // A zero predicate indicates that we should skip the write operation. + if pred_value.is_zero() { + // We only want to write to already initialized memory. + // Do nothing if the predicate is zero. + return Ok(()); + } else { + let value_to_write = get_value(&value_write, initial_witness)?; + self.write_memory_index(memory_index, value_to_write) + } } } } @@ -110,7 +129,7 @@ mod tests { use acir::{ circuit::opcodes::MemOp, - native_types::{Witness, WitnessMap}, + native_types::{Expression, Witness, WitnessMap}, FieldElement, }; @@ -135,8 +154,9 @@ mod tests { block_solver.init(&init, &initial_witness).unwrap(); for op in trace { - block_solver.solve_memory_op(&op, &mut initial_witness).unwrap(); + block_solver.solve_memory_op(&op, &mut initial_witness, &None).unwrap(); } + assert_eq!(initial_witness[&Witness(4)], FieldElement::from(2u128)); } @@ -159,9 +179,10 @@ mod tests { let mut err = None; for op in invalid_trace { if err.is_none() { - err = block_solver.solve_memory_op(&op, &mut initial_witness).err(); + err = block_solver.solve_memory_op(&op, &mut initial_witness, &None).err(); } } + assert!(matches!( err, Some(crate::pwg::OpcodeResolutionError::IndexOutOfBounds { @@ -171,4 +192,68 @@ mod tests { }) )); } + + #[test] + fn test_predicate_on_read() { + let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([ + (Witness(1), FieldElement::from(1u128)), + (Witness(2), FieldElement::from(1u128)), + (Witness(3), FieldElement::from(2u128)), + ])); + + let init = vec![Witness(1), Witness(2)]; + + let invalid_trace = vec![ + MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()), + MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)), + ]; + let mut block_solver = MemoryOpSolver::default(); + block_solver.init(&init, &initial_witness).unwrap(); + let mut err = None; + for op in invalid_trace { + if err.is_none() { + err = block_solver + .solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero())) + .err(); + } + } + + // Should have no index out of bounds error where predicate is zero + assert_eq!(err, None); + // The result of a read under a zero predicate should be zero + assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128)); + } + + #[test] + fn test_predicate_on_write() { + let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([ + (Witness(1), FieldElement::from(1u128)), + (Witness(2), FieldElement::from(1u128)), + (Witness(3), FieldElement::from(2u128)), + ])); + + let init = vec![Witness(1), Witness(2)]; + + let invalid_trace = vec![ + MemOp::write_to_mem_index(FieldElement::from(2u128).into(), Witness(3).into()), + MemOp::read_at_mem_index(FieldElement::from(0u128).into(), Witness(4).into()), + MemOp::read_at_mem_index(FieldElement::from(1u128).into(), Witness(5).into()), + ]; + let mut block_solver = MemoryOpSolver::default(); + block_solver.init(&init, &initial_witness).unwrap(); + let mut err = None; + for op in invalid_trace { + if err.is_none() { + err = block_solver + .solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero())) + .err(); + } + } + + // Should have no index out of bounds error where predicate is zero + assert_eq!(err, None); + // The memory under a zero predicate should be zeroed out + assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128)); + assert_eq!(initial_witness[&Witness(5)], FieldElement::from(0u128)); + } } diff --git a/acvm/src/pwg/mod.rs b/acvm/src/pwg/mod.rs index 350c4c922..8a39311b3 100644 --- a/acvm/src/pwg/mod.rs +++ b/acvm/src/pwg/mod.rs @@ -253,9 +253,9 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { let solver = self.block_solvers.entry(*block_id).or_default(); solver.init(init, &self.witness_map) } - Opcode::MemoryOp { block_id, op } => { + Opcode::MemoryOp { block_id, op, predicate } => { let solver = self.block_solvers.entry(*block_id).or_default(); - solver.solve_memory_op(op, &mut self.witness_map) + solver.solve_memory_op(op, &mut self.witness_map, predicate) } Opcode::Brillig(brillig) => { match BrilligSolver::solve(&mut self.witness_map, brillig, self.backend) { diff --git a/acvm/tests/solver.rs b/acvm/tests/solver.rs index b67654094..38bd51701 100644 --- a/acvm/tests/solver.rs +++ b/acvm/tests/solver.rs @@ -644,8 +644,11 @@ fn memory_operations() { let init = Opcode::MemoryInit { block_id, init: (1..6).map(Witness).collect() }; - let read_op = - Opcode::MemoryOp { block_id, op: MemOp::read_at_mem_index(Witness(6).into(), Witness(7)) }; + let read_op = Opcode::MemoryOp { + block_id, + op: MemOp::read_at_mem_index(Witness(6).into(), Witness(7)), + predicate: None, + }; let expression = Opcode::Arithmetic(Expression { mul_terms: Vec::new(), diff --git a/acvm_js/test/browser/execute_circuit.test.ts b/acvm_js/test/browser/execute_circuit.test.ts index dae83d0cd..407aa830c 100644 --- a/acvm_js/test/browser/execute_circuit.test.ts +++ b/acvm_js/test/browser/execute_circuit.test.ts @@ -162,6 +162,22 @@ it("successfully executes a SchnorrVerify opcode", async () => { expect(solvedWitness).to.be.deep.eq(expectedWitnessMap); }); +it("successfully executes a MemoryOp opcode", async () => { + const { bytecode, initialWitnessMap, expectedWitnessMap } = await import( + "../shared/memory_op" + ); + + const solvedWitness: WitnessMap = await executeCircuit( + bytecode, + initialWitnessMap, + () => { + throw Error("unexpected oracle"); + } + ); + + expect(solvedWitness).to.be.deep.eq(expectedWitnessMap); +}); + it("successfully executes two circuits with same backend", async function () { // chose pedersen op here because it is the one with slow initialization // that led to the decision to pull backend initialization into a separate diff --git a/acvm_js/test/node/execute_circuit.test.ts b/acvm_js/test/node/execute_circuit.test.ts index 08d49836f..3b84d8d9c 100644 --- a/acvm_js/test/node/execute_circuit.test.ts +++ b/acvm_js/test/node/execute_circuit.test.ts @@ -156,6 +156,22 @@ it("successfully executes a SchnorrVerify opcode", async () => { expect(solvedWitness).to.be.deep.eq(expectedWitnessMap); }); +it("successfully executes a MemoryOp opcode", async () => { + const { bytecode, initialWitnessMap, expectedWitnessMap } = await import( + "../shared/memory_op" + ); + + const solvedWitness: WitnessMap = await executeCircuit( + bytecode, + initialWitnessMap, + () => { + throw Error("unexpected oracle"); + } + ); + + expect(solvedWitness).to.be.deep.eq(expectedWitnessMap); +}); + it("successfully executes two circuits with same backend", async function () { this.timeout(10000); diff --git a/acvm_js/test/shared/memory_op.ts b/acvm_js/test/shared/memory_op.ts new file mode 100644 index 000000000..ffb37df34 --- /dev/null +++ b/acvm_js/test/shared/memory_op.ts @@ -0,0 +1,21 @@ +// See `memory_op_circuit` integration test in `acir/tests/test_program_serialization.rs`. +export const bytecode = Uint8Array.from([ + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 146, 49, 14, 0, 32, 8, 3, 139, 192, + 127, 240, 7, 254, 255, 85, 198, 136, 9, 131, 155, 48, 216, 165, 76, 77, 57, + 80, 0, 140, 45, 117, 111, 238, 228, 179, 224, 174, 225, 110, 111, 234, 213, + 185, 148, 156, 203, 121, 89, 86, 13, 215, 126, 131, 43, 153, 187, 115, 40, + 185, 62, 153, 3, 136, 83, 60, 30, 96, 2, 12, 235, 225, 124, 14, 3, 0, 0, +]); + +export const initialWitnessMap = new Map([ + [1, "0x0000000000000000000000000000000000000000000000000000000000000001"], + [2, "0x0000000000000000000000000000000000000000000000000000000000000001"], + [3, "0x0000000000000000000000000000000000000000000000000000000000000002"], +]); + +export const expectedWitnessMap = new Map([ + [1, "0x0000000000000000000000000000000000000000000000000000000000000001"], + [2, "0x0000000000000000000000000000000000000000000000000000000000000001"], + [3, "0x0000000000000000000000000000000000000000000000000000000000000002"], + [4, "0x0000000000000000000000000000000000000000000000000000000000000002"], +]);