Skip to content

Commit

Permalink
feat: allow brillig to read arrays directly from memory (#4460)
Browse files Browse the repository at this point in the history
This PR allows the `BrilligSolver` to read inputs directly from ACIR
memory. This allows us to remove constraints which are generated purely
to load values out of memory to pass into ACIR.

Resolves noir-lang/noir#4262
  • Loading branch information
TomAFrench authored Feb 6, 2024
1 parent d4a7716 commit f99392d
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 14 deletions.
57 changes: 56 additions & 1 deletion barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,15 @@ struct BrilligInputs {
static Array bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array> value;
struct MemoryArray {
Circuit::BlockId value;

friend bool operator==(const MemoryArray&, const MemoryArray&);
std::vector<uint8_t> bincodeSerialize() const;
static MemoryArray bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array, MemoryArray> value;

friend bool operator==(const BrilligInputs&, const BrilligInputs&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4923,6 +4931,53 @@ Circuit::BrilligInputs::Array serde::Deserializable<Circuit::BrilligInputs::Arra

namespace Circuit {

inline bool operator==(const BrilligInputs::MemoryArray& lhs, const BrilligInputs::MemoryArray& rhs)
{
if (!(lhs.value == rhs.value)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BrilligInputs::MemoryArray::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligInputs::MemoryArray>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligInputs::MemoryArray BrilligInputs::MemoryArray::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligInputs::MemoryArray>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BrilligInputs::MemoryArray>::serialize(const Circuit::BrilligInputs::MemoryArray& obj,
Serializer& serializer)
{
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Circuit::BrilligInputs::MemoryArray serde::Deserializable<Circuit::BrilligInputs::MemoryArray>::deserialize(
Deserializer& deserializer)
{
Circuit::BrilligInputs::MemoryArray obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode& lhs, const BrilligOpcode& rhs)
{
if (!(lhs.value == rhs.value)) {
Expand Down
48 changes: 47 additions & 1 deletion noir/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,15 @@ namespace Circuit {
static Array bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array> value;
struct MemoryArray {
Circuit::BlockId value;

friend bool operator==(const MemoryArray&, const MemoryArray&);
std::vector<uint8_t> bincodeSerialize() const;
static MemoryArray bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Single, Array, MemoryArray> value;

friend bool operator==(const BrilligInputs&, const BrilligInputs&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4090,6 +4098,44 @@ Circuit::BrilligInputs::Array serde::Deserializable<Circuit::BrilligInputs::Arra
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligInputs::MemoryArray &lhs, const BrilligInputs::MemoryArray &rhs) {
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> BrilligInputs::MemoryArray::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligInputs::MemoryArray>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligInputs::MemoryArray BrilligInputs::MemoryArray::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligInputs::MemoryArray>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BrilligInputs::MemoryArray>::serialize(const Circuit::BrilligInputs::MemoryArray &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Circuit::BrilligInputs::MemoryArray serde::Deserializable<Circuit::BrilligInputs::MemoryArray>::deserialize(Deserializer &deserializer) {
Circuit::BrilligInputs::MemoryArray obj;
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode &lhs, const BrilligOpcode &rhs) {
Expand Down
2 changes: 2 additions & 0 deletions noir/acvm-repo/acir/src/circuit/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::native_types::{Expression, Witness};
use brillig::Opcode as BrilligOpcode;
use serde::{Deserialize, Serialize};
use super::opcodes::BlockId;

/// Inputs for the Brillig VM. These are the initial inputs
/// that the Brillig VM will use to start.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
pub enum BrilligInputs {
Single(Expression),
Array(Vec<Expression>),
MemoryArray(BlockId)
}

/// Outputs for the Brillig VM. Once the VM has completed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::native_types::{Expression, Witness};
use serde::{Deserialize, Serialize};

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Copy, Default)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Copy, Default)]
pub struct BlockId(pub u32);

/// Operation on a block of memory
Expand Down
15 changes: 12 additions & 3 deletions noir/acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::HashMap;

use acir::{
brillig::{ForeignCallParam, ForeignCallResult, Value},
circuit::{
brillig::{Brillig, BrilligInputs, BrilligOutputs},
OpcodeLocation,
brillig::{Brillig, BrilligInputs, BrilligOutputs}, opcodes::BlockId, OpcodeLocation
},
native_types::WitnessMap,
FieldElement,
Expand All @@ -12,7 +13,7 @@ use brillig_vm::{VMStatus, VM};

use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError};

use super::{get_value, insert_value};
use super::{get_value, insert_value, memory_op::MemoryOpSolver};

#[derive(Debug)]
pub enum BrilligSolverStatus {
Expand Down Expand Up @@ -64,6 +65,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
/// witness.
pub(super) fn new(
initial_witness: &WitnessMap,
memory: &HashMap<BlockId, MemoryOpSolver>,
brillig: &'b Brillig,
bb_solver: &'b B,
acir_index: usize,
Expand Down Expand Up @@ -96,6 +98,13 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
}
}
}
},
BrilligInputs::MemoryArray(block_id) => {
let memory_block = memory.get(block_id).ok_or(OpcodeNotSolvable::MissingMemoryBlock(block_id.0))?;
for memory_index in 0..memory_block.block_len {
let memory_value = memory_block.block_value.get(&memory_index).expect("All memory is initialized on creation");
calldata.push((*memory_value).into());
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions noir/acvm-repo/acvm/src/pwg/memory_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ type MemoryIndex = u32;
/// Maintains the state for solving [`MemoryInit`][`acir::circuit::Opcode::MemoryInit`] and [`MemoryOp`][`acir::circuit::Opcode::MemoryOp`] opcodes.
#[derive(Default)]
pub(super) struct MemoryOpSolver {
block_value: HashMap<MemoryIndex, FieldElement>,
block_len: u32,
pub(super) block_value: HashMap<MemoryIndex, FieldElement>,
pub(super) block_len: u32,
}

impl MemoryOpSolver {
Expand Down
6 changes: 4 additions & 2 deletions noir/acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ pub enum StepResult<'a, B: BlackBoxFunctionSolver> {
pub enum OpcodeNotSolvable {
#[error("missing assignment for witness index {0}")]
MissingAssignment(u32),
#[error("Attempted to load uninitialized memory block")]
MissingMemoryBlock(u32),
#[error("expression has too many unknowns {0}")]
ExpressionHasTooManyUnknowns(Expression),
}
Expand Down Expand Up @@ -336,7 +338,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
// there will be a cached `BrilligSolver` to avoid recomputation.
let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() {
Some(solver) => solver,
None => BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)?,
None => BrilligSolver::new(witness, &self.block_solvers, brillig, self.backend, self.instruction_pointer)?,
};
match solver.solve()? {
BrilligSolverStatus::ForeignCallWait(foreign_call) => {
Expand Down Expand Up @@ -371,7 +373,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
return StepResult::Status(self.handle_opcode_resolution(resolution));
}

let solver = BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer);
let solver = BrilligSolver::new(witness, &self.block_solvers, brillig, self.backend, self.instruction_pointer);
match solver {
Ok(solver) => StepResult::IntoBrillig(solver),
Err(..) => StepResult::Status(self.handle_opcode_resolution(solver.map(|_| ()))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1453,10 +1453,8 @@ impl AcirContext {
}
Ok(BrilligInputs::Array(var_expressions))
}
AcirValue::DynamicArray(_) => {
let mut var_expressions = Vec::new();
self.brillig_array_input(&mut var_expressions, i)?;
Ok(BrilligInputs::Array(var_expressions))
AcirValue::DynamicArray(AcirDynamicArray { block_id,.. }) => {
Ok(BrilligInputs::MemoryArray(block_id))
}
}
})?;
Expand Down Expand Up @@ -1870,6 +1868,9 @@ fn execute_brillig(code: &[BrilligOpcode], inputs: &[BrilligInputs]) -> Option<V
calldata.push(expr.to_const()?.into());
}
}
BrilligInputs::MemoryArray(_) => {
return None;
}
}
}

Expand Down

0 comments on commit f99392d

Please sign in to comment.