Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make ACVM generic across fields #5114

Merged
merged 10 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions acvm-repo/acir/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ criterion.workspace = true
pprof.workspace = true

[features]
default = ["bn254"]
bn254 = ["acir_field/bn254", "brillig/bn254"]
bls12_381 = ["acir_field/bls12_381", "brillig/bls12_381"]
bn254 = ["acir_field/bn254"]
bls12_381 = ["acir_field/bls12_381"]

[[bench]]
name = "serialization"
Expand Down
8 changes: 4 additions & 4 deletions acvm-repo/acir/benches/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use pprof::criterion::{Output, PProfProfiler};

const SIZES: [usize; 9] = [10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000];

fn sample_program(num_opcodes: usize) -> Program {
let assert_zero_opcodes: Vec<Opcode> = (0..num_opcodes)
fn sample_program(num_opcodes: usize) -> Program<FieldElement> {
let assert_zero_opcodes: Vec<Opcode<_>> = (0..num_opcodes)
.map(|i| {
Opcode::AssertZero(Expression {
mul_terms: vec![(
Expand Down Expand Up @@ -83,7 +83,7 @@ fn bench_deserialization(c: &mut Criterion) {
BenchmarkId::from_parameter(size),
&serialized_program,
|b, program| {
b.iter(|| Program::deserialize_program(program));
b.iter(|| Program::<FieldElement>::deserialize_program(program));
},
);
}
Expand All @@ -107,7 +107,7 @@ fn bench_deserialization(c: &mut Criterion) {
|b, program| {
b.iter(|| {
let mut deserializer = serde_json::Deserializer::from_slice(program);
Program::deserialize_program_base64(&mut deserializer)
Program::<FieldElement>::deserialize_program_base64(&mut deserializer)
});
},
);
Expand Down
10 changes: 5 additions & 5 deletions acvm-repo/acir/src/circuit/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use serde::{Deserialize, Serialize};
/// Inputs for the Brillig VM. These are the initial inputs
/// that the Brillig VM will use to start.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
pub enum BrilligInputs {
Single(Expression),
Array(Vec<Expression>),
pub enum BrilligInputs<F> {
Single(Expression<F>),
Array(Vec<Expression<F>>),
MemoryArray(BlockId),
}

Expand All @@ -24,6 +24,6 @@ pub enum BrilligOutputs {
/// a full Brillig function to be executed by the Brillig VM.
/// This is stored separately on a program and accessed through a [BrilligPointer].
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Debug)]
pub struct BrilligBytecode {
pub bytecode: Vec<BrilligOpcode>,
pub struct BrilligBytecode<F> {
pub bytecode: Vec<BrilligOpcode<F>>,
}
4 changes: 2 additions & 2 deletions acvm-repo/acir/src/circuit/directives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
/// Directives do not apply any constraints.
/// You can think of them as opcodes that allow one to use non-determinism
/// In the future, this can be replaced with asm non-determinism blocks
pub enum Directive {
pub enum Directive<F> {
//decomposition of a: a=\sum b[i]*radix^i where b is an array of witnesses < radix in little endian form
ToLeRadix { a: Expression, b: Vec<Witness>, radix: u32 },
ToLeRadix { a: Expression<F>, b: Vec<Witness>, radix: u32 },
}
94 changes: 50 additions & 44 deletions acvm-repo/acir/src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub mod directives;
pub mod opcodes;

use crate::native_types::{Expression, Witness};
use acir_field::FieldElement;
use acir_field::AcirField;
pub use opcodes::Opcode;
use thiserror::Error;

Expand Down Expand Up @@ -38,17 +38,17 @@ pub enum ExpressionWidth {
/// A program represented by multiple ACIR circuits. The execution trace of these
/// circuits is dictated by construction of the [crate::native_types::WitnessStack].
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct Program {
pub functions: Vec<Circuit>,
pub unconstrained_functions: Vec<BrilligBytecode>,
pub struct Program<F> {
pub functions: Vec<Circuit<F>>,
pub unconstrained_functions: Vec<BrilligBytecode<F>>,
}

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct Circuit {
pub struct Circuit<F> {
// current_witness_index is the highest witness index in the circuit. The next witness to be added to this circuit
// will take on this value. (The value is cached here as an optimization.)
pub current_witness_index: u32,
pub opcodes: Vec<Opcode>,
pub opcodes: Vec<Opcode<F>>,
pub expression_width: ExpressionWidth,

/// The set of private inputs to the circuit.
Expand All @@ -67,7 +67,7 @@ pub struct Circuit {
// Note: This should be a BTreeMap, but serde-reflect is creating invalid
// c++ code at the moment when it is, due to OpcodeLocation needing a comparison
// implementation which is never generated.
pub assert_messages: Vec<(OpcodeLocation, AssertionPayload)>,
pub assert_messages: Vec<(OpcodeLocation, AssertionPayload<F>)>,

/// States whether the backend should use a SNARK recursion friendly prover.
/// If implemented by a backend, this means that proofs generated with this circuit
Expand All @@ -76,15 +76,15 @@ pub struct Circuit {
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpressionOrMemory {
Expression(Expression),
pub enum ExpressionOrMemory<F> {
Expression(Expression<F>),
Memory(BlockId),
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AssertionPayload {
pub enum AssertionPayload<F> {
StaticString(String),
Dynamic(/* error_selector */ u64, Vec<ExpressionOrMemory>),
Dynamic(/* error_selector */ u64, Vec<ExpressionOrMemory<F>>),
}

#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
Expand Down Expand Up @@ -127,15 +127,15 @@ impl<'de> Deserialize<'de> for ErrorSelector {
pub const STRING_ERROR_SELECTOR: ErrorSelector = ErrorSelector(0);

#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct RawAssertionPayload {
pub struct RawAssertionPayload<F> {
pub selector: ErrorSelector,
pub data: Vec<FieldElement>,
pub data: Vec<F>,
}

#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ResolvedAssertionPayload {
pub enum ResolvedAssertionPayload<F> {
String(String),
Raw(RawAssertionPayload),
Raw(RawAssertionPayload<F>),
}

#[derive(Debug, Copy, Clone)]
Expand Down Expand Up @@ -204,7 +204,7 @@ impl FromStr for OpcodeLocation {
}
}

impl Circuit {
impl<F: AcirField> Circuit<F> {
pub fn num_vars(&self) -> u32 {
self.current_witness_index + 1
}
Expand All @@ -223,7 +223,7 @@ impl Circuit {
}
}

impl Program {
impl<F: Serialize> Program<F> {
fn write<W: std::io::Write>(&self, writer: W) -> std::io::Result<()> {
let buf = bincode::serialize(self).unwrap();
let mut encoder = flate2::write::GzEncoder::new(writer, Compression::default());
Expand All @@ -232,36 +232,38 @@ impl Program {
Ok(())
}

fn read<R: std::io::Read>(reader: R) -> std::io::Result<Self> {
let mut gz_decoder = flate2::read::GzDecoder::new(reader);
let mut buf_d = Vec::new();
gz_decoder.read_to_end(&mut buf_d)?;
bincode::deserialize(&buf_d)
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
}

pub fn serialize_program(program: &Program) -> Vec<u8> {
pub fn serialize_program(program: &Self) -> Vec<u8> {
let mut program_bytes: Vec<u8> = Vec::new();
program.write(&mut program_bytes).expect("expected circuit to be serializable");
program_bytes
}

pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
Program::read(serialized_circuit)
}

// Serialize and base64 encode program
pub fn serialize_program_base64<S>(program: &Program, s: S) -> Result<S::Ok, S::Error>
pub fn serialize_program_base64<S>(program: &Self, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let program_bytes = Program::serialize_program(program);
let encoded_b64 = base64::engine::general_purpose::STANDARD.encode(program_bytes);
s.serialize_str(&encoded_b64)
}
}

impl<F: for<'a> Deserialize<'a>> Program<F> {
fn read<R: std::io::Read>(reader: R) -> std::io::Result<Self> {
let mut gz_decoder = flate2::read::GzDecoder::new(reader);
let mut buf_d = Vec::new();
gz_decoder.read_to_end(&mut buf_d)?;
bincode::deserialize(&buf_d)
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
}

pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
Program::read(serialized_circuit)
}

// Deserialize and base64 decode program
pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Program, D::Error>
pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Expand All @@ -274,7 +276,7 @@ impl Program {
}
}

impl std::fmt::Display for Circuit {
impl<F: AcirField> std::fmt::Display for Circuit<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "current witness index : {}", self.current_witness_index)?;

Expand Down Expand Up @@ -313,13 +315,13 @@ impl std::fmt::Display for Circuit {
}
}

impl std::fmt::Debug for Circuit {
impl<F: AcirField> std::fmt::Debug for Circuit<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}

impl std::fmt::Display for Program {
impl<F: AcirField> std::fmt::Display for Program<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for (func_index, function) in self.functions.iter().enumerate() {
writeln!(f, "func {}", func_index)?;
Expand All @@ -333,7 +335,7 @@ impl std::fmt::Display for Program {
}
}

impl std::fmt::Debug for Program {
impl<F: AcirField> std::fmt::Debug for Program<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
Expand Down Expand Up @@ -365,21 +367,22 @@ mod tests {
circuit::{ExpressionWidth, Program},
native_types::Witness,
};
use acir_field::FieldElement;
use acir_field::{AcirField, FieldElement};
use serde::{Deserialize, Serialize};

fn and_opcode() -> Opcode {
fn and_opcode<F: AcirField>() -> Opcode<F> {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND {
lhs: FunctionInput { witness: Witness(1), num_bits: 4 },
rhs: FunctionInput { witness: Witness(2), num_bits: 4 },
output: Witness(3),
})
}
fn range_opcode() -> Opcode {
fn range_opcode<F: AcirField>() -> Opcode<F> {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input: FunctionInput { witness: Witness(1), num_bits: 8 },
})
}
fn keccakf1600_opcode() -> Opcode {
fn keccakf1600_opcode<F: AcirField>() -> Opcode<F> {
let inputs: Box<[FunctionInput; 25]> = Box::new(std::array::from_fn(|i| FunctionInput {
witness: Witness(i as u32 + 1),
num_bits: 8,
Expand All @@ -388,7 +391,7 @@ mod tests {

Opcode::BlackBoxFuncCall(BlackBoxFuncCall::Keccakf1600 { inputs, outputs })
}
fn schnorr_verify_opcode() -> Opcode {
fn schnorr_verify_opcode<F: AcirField>() -> Opcode<F> {
let public_key_x =
FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() };
let public_key_y =
Expand All @@ -413,7 +416,7 @@ mod tests {
let circuit = Circuit {
current_witness_index: 5,
expression_width: ExpressionWidth::Unbounded,
opcodes: vec![and_opcode(), range_opcode(), schnorr_verify_opcode()],
opcodes: vec![and_opcode::<FieldElement>(), range_opcode(), schnorr_verify_opcode()],
private_parameters: BTreeSet::new(),
public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(12)])),
return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(4), Witness(12)])),
Expand All @@ -422,7 +425,9 @@ mod tests {
};
let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

fn read_write(program: Program) -> (Program, Program) {
fn read_write<F: AcirField + Serialize + for<'a> Deserialize<'a>>(
program: Program<F>,
) -> (Program<F>, Program<F>) {
let bytes = Program::serialize_program(&program);
let got_program = Program::deserialize_program(&bytes).unwrap();
(program, got_program)
Expand Down Expand Up @@ -475,7 +480,8 @@ mod tests {
encoder.write_all(bad_circuit).unwrap();
encoder.finish().unwrap();

let deserialization_result = Program::deserialize_program(&zipped_bad_circuit);
let deserialization_result: Result<Program<FieldElement>, _> =
Program::deserialize_program(&zipped_bad_circuit);
assert!(deserialization_result.is_err());
}
}
Loading
Loading