Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

feat(acir)!: Add predicate to MemoryOp #503

Merged
merged 10 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub enum Opcode {
MemoryOp {
block_id: BlockId,
op: MemOp,
/// Predicate of the memory operation - indicates if it should be skipped
predicate: Option<Expression>,
},
MemoryInit {
block_id: BlockId,
Expand Down Expand Up @@ -158,8 +160,12 @@ impl std::fmt::Display for Opcode {
writeln!(f, "outputs: {:?}", brillig.outputs)?;
writeln!(f, "{:?}", brillig.bytecode)
}
Opcode::MemoryOp { block_id, op } => {
Opcode::MemoryOp { block_id, op, predicate } => {
write!(f, "MEM ")?;
if let Some(pred) = predicate {
writeln!(f, "PREDICATE = {pred}")?;
}

let is_read = op.operation.is_zero();
let is_write = op.operation == Expression::one();
if is_read {
Expand Down
39 changes: 38 additions & 1 deletion acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use acir::{
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
directives::Directive,
opcodes::{BlackBoxFuncCall, FunctionInput},
opcodes::{BlackBoxFuncCall, BlockId, FunctionInput, MemOp},
Circuit, Opcode, PublicInputs,
},
native_types::{Expression, Witness},
Expand Down Expand Up @@ -340,3 +340,40 @@ fn complex_brillig_foreign_call() {

assert_eq!(bytes, expected_serialization)
}

#[test]
fn memory_op_circuit() {
let init = vec![Witness(1), Witness(2)];

let memory_init = Opcode::MemoryInit { block_id: BlockId(0), init };
let write = Opcode::MemoryOp {
block_id: BlockId(0),
op: MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
predicate: None,
};
let read = Opcode::MemoryOp {
block_id: BlockId(0),
op: MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)),
predicate: None,
};

let circuit = Circuit {
current_witness_index: 5,
opcodes: vec![memory_init, write, read],
private_parameters: BTreeSet::from([Witness(1), Witness(2), Witness(3)]),
return_values: PublicInputs([Witness(4)].into()),
..Circuit::default()
};
let mut bytes = Vec::new();
circuit.write(&mut bytes).unwrap();

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 146, 49, 14, 0, 32, 8, 3, 139, 192, 127, 240, 7,
254, 255, 85, 198, 136, 9, 131, 155, 48, 216, 165, 76, 77, 57, 80, 0, 140, 45, 117, 111,
238, 228, 179, 224, 174, 225, 110, 111, 234, 213, 185, 148, 156, 203, 121, 89, 86, 13, 215,
126, 131, 43, 153, 187, 115, 40, 185, 62, 153, 3, 136, 83, 60, 30, 96, 2, 12, 235, 225,
124, 14, 3, 0, 0,
];

assert_eq!(bytes, expected_serialization)
}
101 changes: 93 additions & 8 deletions acvm/src/pwg/memory_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;

use acir::{
circuit::opcodes::MemOp,
native_types::{Witness, WitnessMap},
native_types::{Expression, Witness, WitnessMap},
FieldElement,
};

Expand Down Expand Up @@ -63,6 +63,7 @@ impl MemoryOpSolver {
&mut self,
op: &MemOp,
initial_witness: &mut WitnessMap,
predicate: &Option<Expression>,
) -> Result<(), OpcodeResolutionError> {
let operation = get_value(&op.operation, initial_witness)?;

Expand All @@ -79,6 +80,12 @@ impl MemoryOpSolver {
// `operation == 0` implies a read operation. (`operation == 1` implies write operation).
let is_read_operation = operation.is_zero();

// If the predicate is `None`, then we simply return the value 1
let pred_value = match predicate {
Some(pred) => get_value(pred, initial_witness),
None => Ok(FieldElement::one()),
}?;

if is_read_operation {
// `value_read = arr[memory_index]`
//
Expand All @@ -88,7 +95,13 @@ impl MemoryOpSolver {
"Memory must be read into a specified witness index, encountered an Expression",
);

let value_in_array = self.read_memory_index(memory_index)?;
// A zero predicate indicates that we should skip the read operation
// and zero out the operation's output.
let value_in_array = if pred_value.is_zero() {
FieldElement::zero()
} else {
self.read_memory_index(memory_index)?
};
insert_value(&value_read_witness, value_in_array, initial_witness)
} else {
// `arr[memory_index] = value_write`
Expand All @@ -97,9 +110,15 @@ impl MemoryOpSolver {
// into the memory block.
let value_write = value;

let value_to_write = get_value(&value_write, initial_witness)?;

self.write_memory_index(memory_index, value_to_write)
// A zero predicate indicates that we should skip the write operation.
if pred_value.is_zero() {
// We only want to write to already initialized memory.
// Do nothing if the predicate is zero.
return Ok(());
} else {
let value_to_write = get_value(&value_write, initial_witness)?;
self.write_memory_index(memory_index, value_to_write)
}
}
}
}
Expand All @@ -110,7 +129,7 @@ mod tests {

use acir::{
circuit::opcodes::MemOp,
native_types::{Witness, WitnessMap},
native_types::{Expression, Witness, WitnessMap},
FieldElement,
};

Expand All @@ -135,8 +154,9 @@ mod tests {
block_solver.init(&init, &initial_witness).unwrap();

for op in trace {
block_solver.solve_memory_op(&op, &mut initial_witness).unwrap();
block_solver.solve_memory_op(&op, &mut initial_witness, &None).unwrap();
}

assert_eq!(initial_witness[&Witness(4)], FieldElement::from(2u128));
}

Expand All @@ -159,9 +179,10 @@ mod tests {
let mut err = None;
for op in invalid_trace {
if err.is_none() {
err = block_solver.solve_memory_op(&op, &mut initial_witness).err();
err = block_solver.solve_memory_op(&op, &mut initial_witness, &None).err();
}
}

assert!(matches!(
err,
Some(crate::pwg::OpcodeResolutionError::IndexOutOfBounds {
Expand All @@ -171,4 +192,68 @@ mod tests {
})
));
}

#[test]
fn test_predicate_on_read() {
let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
(Witness(1), FieldElement::from(1u128)),
(Witness(2), FieldElement::from(1u128)),
(Witness(3), FieldElement::from(2u128)),
]));

let init = vec![Witness(1), Witness(2)];

let invalid_trace = vec![
MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
];
let mut block_solver = MemoryOpSolver::default();
block_solver.init(&init, &initial_witness).unwrap();
let mut err = None;
for op in invalid_trace {
if err.is_none() {
err = block_solver
.solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
.err();
}
}

// Should have no index out of bounds error where predicate is zero
assert_eq!(err, None);
// The result of a read under a zero predicate should be zero
assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128));
}

#[test]
fn test_predicate_on_write() {
let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
(Witness(1), FieldElement::from(1u128)),
(Witness(2), FieldElement::from(1u128)),
(Witness(3), FieldElement::from(2u128)),
]));

let init = vec![Witness(1), Witness(2)];

let invalid_trace = vec![
MemOp::write_to_mem_index(FieldElement::from(2u128).into(), Witness(3).into()),
MemOp::read_at_mem_index(FieldElement::from(0u128).into(), Witness(4).into()),
MemOp::read_at_mem_index(FieldElement::from(1u128).into(), Witness(5).into()),
];
let mut block_solver = MemoryOpSolver::default();
block_solver.init(&init, &initial_witness).unwrap();
let mut err = None;
for op in invalid_trace {
if err.is_none() {
err = block_solver
.solve_memory_op(&op, &mut initial_witness, &Some(Expression::zero()))
.err();
}
}

// Should have no index out of bounds error where predicate is zero
assert_eq!(err, None);
// The memory under a zero predicate should be zeroed out
assert_eq!(initial_witness[&Witness(4)], FieldElement::from(0u128));
assert_eq!(initial_witness[&Witness(5)], FieldElement::from(0u128));
}
}
4 changes: 2 additions & 2 deletions acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> {
let solver = self.block_solvers.entry(*block_id).or_default();
solver.init(init, &self.witness_map)
}
Opcode::MemoryOp { block_id, op } => {
Opcode::MemoryOp { block_id, op, predicate } => {
let solver = self.block_solvers.entry(*block_id).or_default();
solver.solve_memory_op(op, &mut self.witness_map)
solver.solve_memory_op(op, &mut self.witness_map, predicate)
}
Opcode::Brillig(brillig) => {
match BrilligSolver::solve(&mut self.witness_map, brillig, self.backend) {
Expand Down
7 changes: 5 additions & 2 deletions acvm/tests/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,11 @@ fn memory_operations() {

let init = Opcode::MemoryInit { block_id, init: (1..6).map(Witness).collect() };

let read_op =
Opcode::MemoryOp { block_id, op: MemOp::read_at_mem_index(Witness(6).into(), Witness(7)) };
let read_op = Opcode::MemoryOp {
block_id,
op: MemOp::read_at_mem_index(Witness(6).into(), Witness(7)),
predicate: None,
};

let expression = Opcode::Arithmetic(Expression {
mul_terms: Vec::new(),
Expand Down
16 changes: 16 additions & 0 deletions acvm_js/test/browser/execute_circuit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,22 @@ it("successfully executes a SchnorrVerify opcode", async () => {
expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes a MemoryOp opcode", async () => {
const { bytecode, initialWitnessMap, expectedWitnessMap } = await import(
"../shared/memory_op"
);

const solvedWitness: WitnessMap = await executeCircuit(
bytecode,
initialWitnessMap,
() => {
throw Error("unexpected oracle");
}
);

expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes two circuits with same backend", async function () {
// chose pedersen op here because it is the one with slow initialization
// that led to the decision to pull backend initialization into a separate
Expand Down
16 changes: 16 additions & 0 deletions acvm_js/test/node/execute_circuit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ it("successfully executes a SchnorrVerify opcode", async () => {
expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes a MemoryOp opcode", async () => {
const { bytecode, initialWitnessMap, expectedWitnessMap } = await import(
"../shared/memory_op"
);

const solvedWitness: WitnessMap = await executeCircuit(
bytecode,
initialWitnessMap,
() => {
throw Error("unexpected oracle");
}
);

expect(solvedWitness).to.be.deep.eq(expectedWitnessMap);
});

it("successfully executes two circuits with same backend", async function () {
this.timeout(10000);

Expand Down
21 changes: 21 additions & 0 deletions acvm_js/test/shared/memory_op.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// See `memory_op_circuit` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 146, 49, 14, 0, 32, 8, 3, 139, 192,
127, 240, 7, 254, 255, 85, 198, 136, 9, 131, 155, 48, 216, 165, 76, 77, 57,
80, 0, 140, 45, 117, 111, 238, 228, 179, 224, 174, 225, 110, 111, 234, 213,
185, 148, 156, 203, 121, 89, 86, 13, 215, 126, 131, 43, 153, 187, 115, 40,
185, 62, 153, 3, 136, 83, 60, 30, 96, 2, 12, 235, 225, 124, 14, 3, 0, 0,
]);

export const initialWitnessMap = new Map([
[1, "0x0000000000000000000000000000000000000000000000000000000000000001"],
[2, "0x0000000000000000000000000000000000000000000000000000000000000001"],
[3, "0x0000000000000000000000000000000000000000000000000000000000000002"],
]);

export const expectedWitnessMap = new Map([
[1, "0x0000000000000000000000000000000000000000000000000000000000000001"],
[2, "0x0000000000000000000000000000000000000000000000000000000000000001"],
[3, "0x0000000000000000000000000000000000000000000000000000000000000002"],
[4, "0x0000000000000000000000000000000000000000000000000000000000000002"],
]);