From 51162e6b3c4fb83ee604df3142e24fd70ac756b7 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Tue, 6 Feb 2024 17:45:05 +0000 Subject: [PATCH] feat: allow brillig to read arrays directly from memory (#4460) This PR allows the `BrilligSolver` to read inputs directly from ACIR memory. This allows us to remove constraints which are generated purely to load values out of memory to pass into ACIR. Resolves https://github.com/noir-lang/noir/issues/4262 --- .../dsl/acir_format/serde/acir.hpp | 57 ++++++++++++++++++- noir/acvm-repo/acir/codegen/acir.cpp | 48 +++++++++++++++- noir/acvm-repo/acir/src/circuit/brillig.rs | 2 + .../src/circuit/opcodes/memory_operation.rs | 2 +- noir/acvm-repo/acvm/src/pwg/brillig.rs | 15 ++++- noir/acvm-repo/acvm/src/pwg/memory_op.rs | 4 +- noir/acvm-repo/acvm/src/pwg/mod.rs | 6 +- .../src/ssa/acir_gen/acir_ir/acir_variable.rs | 9 +-- 8 files changed, 129 insertions(+), 14 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index 7de4fe12cbe..5d9a5b6a5a6 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -343,7 +343,15 @@ struct BrilligInputs { static Array bincodeDeserialize(std::vector); }; - std::variant value; + struct MemoryArray { + Circuit::BlockId value; + + friend bool operator==(const MemoryArray&, const MemoryArray&); + std::vector bincodeSerialize() const; + static MemoryArray bincodeDeserialize(std::vector); + }; + + std::variant value; friend bool operator==(const BrilligInputs&, const BrilligInputs&); std::vector bincodeSerialize() const; @@ -4923,6 +4931,53 @@ Circuit::BrilligInputs::Array serde::Deserializable BrilligInputs::MemoryArray::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline BrilligInputs::MemoryArray BrilligInputs::MemoryArray::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::BrilligInputs::MemoryArray& obj, + Serializer& serializer) +{ + serde::Serializable::serialize(obj.value, serializer); +} + +template <> +template +Circuit::BrilligInputs::MemoryArray serde::Deserializable::deserialize( + Deserializer& deserializer) +{ + Circuit::BrilligInputs::MemoryArray obj; + obj.value = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Circuit { + inline bool operator==(const BrilligOpcode& lhs, const BrilligOpcode& rhs) { if (!(lhs.value == rhs.value)) { diff --git a/noir/acvm-repo/acir/codegen/acir.cpp b/noir/acvm-repo/acir/codegen/acir.cpp index 73c6054a603..18fc17f6926 100644 --- a/noir/acvm-repo/acir/codegen/acir.cpp +++ b/noir/acvm-repo/acir/codegen/acir.cpp @@ -318,7 +318,15 @@ namespace Circuit { static Array bincodeDeserialize(std::vector); }; - std::variant value; + struct MemoryArray { + Circuit::BlockId value; + + friend bool operator==(const MemoryArray&, const MemoryArray&); + std::vector bincodeSerialize() const; + static MemoryArray bincodeDeserialize(std::vector); + }; + + std::variant value; friend bool operator==(const BrilligInputs&, const BrilligInputs&); std::vector bincodeSerialize() const; @@ -4090,6 +4098,44 @@ Circuit::BrilligInputs::Array serde::Deserializable BrilligInputs::MemoryArray::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline BrilligInputs::MemoryArray BrilligInputs::MemoryArray::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::BrilligInputs::MemoryArray &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.value, serializer); +} + +template <> +template +Circuit::BrilligInputs::MemoryArray serde::Deserializable::deserialize(Deserializer &deserializer) { + Circuit::BrilligInputs::MemoryArray obj; + obj.value = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Circuit { inline bool operator==(const BrilligOpcode &lhs, const BrilligOpcode &rhs) { diff --git a/noir/acvm-repo/acir/src/circuit/brillig.rs b/noir/acvm-repo/acir/src/circuit/brillig.rs index 63c6ad2a3d4..1a636fa32bc 100644 --- a/noir/acvm-repo/acir/src/circuit/brillig.rs +++ b/noir/acvm-repo/acir/src/circuit/brillig.rs @@ -1,6 +1,7 @@ use crate::native_types::{Expression, Witness}; use brillig::Opcode as BrilligOpcode; use serde::{Deserialize, Serialize}; +use super::opcodes::BlockId; /// Inputs for the Brillig VM. These are the initial inputs /// that the Brillig VM will use to start. @@ -8,6 +9,7 @@ use serde::{Deserialize, Serialize}; pub enum BrilligInputs { Single(Expression), Array(Vec), + MemoryArray(BlockId) } /// Outputs for the Brillig VM. Once the VM has completed diff --git a/noir/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs b/noir/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs index 9e45dc4ee8c..0e94c0f051e 100644 --- a/noir/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs +++ b/noir/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs @@ -1,7 +1,7 @@ use crate::native_types::{Expression, Witness}; use serde::{Deserialize, Serialize}; -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Copy, Default)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Copy, Default)] pub struct BlockId(pub u32); /// Operation on a block of memory diff --git a/noir/acvm-repo/acvm/src/pwg/brillig.rs b/noir/acvm-repo/acvm/src/pwg/brillig.rs index 4b62f86f0eb..69e52b6dfb4 100644 --- a/noir/acvm-repo/acvm/src/pwg/brillig.rs +++ b/noir/acvm-repo/acvm/src/pwg/brillig.rs @@ -1,8 +1,9 @@ +use std::collections::HashMap; + use acir::{ brillig::{ForeignCallParam, ForeignCallResult, Value}, circuit::{ - brillig::{Brillig, BrilligInputs, BrilligOutputs}, - OpcodeLocation, + brillig::{Brillig, BrilligInputs, BrilligOutputs}, opcodes::BlockId, OpcodeLocation }, native_types::WitnessMap, FieldElement, @@ -12,7 +13,7 @@ use brillig_vm::{VMStatus, VM}; use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError}; -use super::{get_value, insert_value}; +use super::{get_value, insert_value, memory_op::MemoryOpSolver}; #[derive(Debug)] pub enum BrilligSolverStatus { @@ -64,6 +65,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { /// witness. pub(super) fn new( initial_witness: &WitnessMap, + memory: &HashMap, brillig: &'b Brillig, bb_solver: &'b B, acir_index: usize, @@ -96,6 +98,13 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { } } } + }, + BrilligInputs::MemoryArray(block_id) => { + let memory_block = memory.get(block_id).ok_or(OpcodeNotSolvable::MissingMemoryBlock(block_id.0))?; + for memory_index in 0..memory_block.block_len { + let memory_value = memory_block.block_value.get(&memory_index).expect("All memory is initialized on creation"); + calldata.push((*memory_value).into()); + } } } } diff --git a/noir/acvm-repo/acvm/src/pwg/memory_op.rs b/noir/acvm-repo/acvm/src/pwg/memory_op.rs index c1da2cd95cf..49ec652289e 100644 --- a/noir/acvm-repo/acvm/src/pwg/memory_op.rs +++ b/noir/acvm-repo/acvm/src/pwg/memory_op.rs @@ -14,8 +14,8 @@ type MemoryIndex = u32; /// Maintains the state for solving [`MemoryInit`][`acir::circuit::Opcode::MemoryInit`] and [`MemoryOp`][`acir::circuit::Opcode::MemoryOp`] opcodes. #[derive(Default)] pub(super) struct MemoryOpSolver { - block_value: HashMap, - block_len: u32, + pub(super) block_value: HashMap, + pub(super) block_len: u32, } impl MemoryOpSolver { diff --git a/noir/acvm-repo/acvm/src/pwg/mod.rs b/noir/acvm-repo/acvm/src/pwg/mod.rs index b6499c54224..d28f9134d98 100644 --- a/noir/acvm-repo/acvm/src/pwg/mod.rs +++ b/noir/acvm-repo/acvm/src/pwg/mod.rs @@ -79,6 +79,8 @@ pub enum StepResult<'a, B: BlackBoxFunctionSolver> { pub enum OpcodeNotSolvable { #[error("missing assignment for witness index {0}")] MissingAssignment(u32), + #[error("Attempted to load uninitialized memory block")] + MissingMemoryBlock(u32), #[error("expression has too many unknowns {0}")] ExpressionHasTooManyUnknowns(Expression), } @@ -336,7 +338,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { // there will be a cached `BrilligSolver` to avoid recomputation. let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() { Some(solver) => solver, - None => BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)?, + None => BrilligSolver::new(witness, &self.block_solvers, brillig, self.backend, self.instruction_pointer)?, }; match solver.solve()? { BrilligSolverStatus::ForeignCallWait(foreign_call) => { @@ -371,7 +373,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { return StepResult::Status(self.handle_opcode_resolution(resolution)); } - let solver = BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer); + let solver = BrilligSolver::new(witness, &self.block_solvers, brillig, self.backend, self.instruction_pointer); match solver { Ok(solver) => StepResult::IntoBrillig(solver), Err(..) => StepResult::Status(self.handle_opcode_resolution(solver.map(|_| ()))), diff --git a/noir/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/noir/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index 19f5b208802..5f2a531aca4 100644 --- a/noir/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/noir/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -1453,10 +1453,8 @@ impl AcirContext { } Ok(BrilligInputs::Array(var_expressions)) } - AcirValue::DynamicArray(_) => { - let mut var_expressions = Vec::new(); - self.brillig_array_input(&mut var_expressions, i)?; - Ok(BrilligInputs::Array(var_expressions)) + AcirValue::DynamicArray(AcirDynamicArray { block_id,.. }) => { + Ok(BrilligInputs::MemoryArray(block_id)) } } })?; @@ -1870,6 +1868,9 @@ fn execute_brillig(code: &[BrilligOpcode], inputs: &[BrilligInputs]) -> Option { + return None; + } } }