diff --git a/acir/src/circuit/mod.rs b/acir/src/circuit/mod.rs index a6b0fdb07..6bfe9098e 100644 --- a/acir/src/circuit/mod.rs +++ b/acir/src/circuit/mod.rs @@ -16,6 +16,8 @@ const VERSION_NUMBER: u32 = 0; #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct Circuit { + // current_witness_index is the highest witness index in the circuit. The next witness to be added to this circuit + // will take on this value. (The value is cached here as an optimization.) pub current_witness_index: u32, pub opcodes: Vec, pub public_inputs: PublicInputs, @@ -204,7 +206,7 @@ mod test { Opcode::Arithmetic(crate::native_types::Expression { mul_terms: vec![], linear_combinations: vec![], - q_c: FieldElement::from_hex("FFFF").unwrap(), + q_c: FieldElement::from(8u128), }), range_opcode(), and_opcode(), diff --git a/acir/src/native_types/witness.rs b/acir/src/native_types/witness.rs index d7753ec7c..b81eee65c 100644 --- a/acir/src/native_types/witness.rs +++ b/acir/src/native_types/witness.rs @@ -6,6 +6,7 @@ use flate2::{ }; use serde::{Deserialize, Serialize}; +// Witness might be a misnomer. This is an index that represents the position a witness will take #[derive( Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize, )] diff --git a/acvm/src/compiler.rs b/acvm/src/compiler.rs index 412921c77..bc27e875a 100644 --- a/acvm/src/compiler.rs +++ b/acvm/src/compiler.rs @@ -1,6 +1,6 @@ // The various passes that we can use over ACIR -pub mod fallback; -pub mod optimizer; +pub mod optimizers; +pub mod transformers; use crate::Language; use acir::{ @@ -9,10 +9,9 @@ use acir::{ BlackBoxFunc, }; use indexmap::IndexMap; -use optimizer::{CSatOptimizer, GeneralOptimizer}; +use optimizers::GeneralOptimizer; use thiserror::Error; - -use self::{fallback::IsBlackBoxSupported, optimizer::R1CSOptimizer}; +use transformers::{CSatTransformer, FallbackTransformer, IsBlackBoxSupported, R1CSTransformer}; #[derive(PartialEq, Eq, Debug, Error)] pub enum CompileError { @@ -29,33 +28,48 @@ pub fn compile( // Currently the optimizer and reducer are one in the same // for CSAT - // Fallback pass - let fallback = fallback::fallback(acir, is_black_box_supported)?; + // Fallback transformer pass + let acir = FallbackTransformer::transform(acir, is_black_box_supported)?; + + // General optimizer pass + let mut opcodes: Vec = Vec::new(); + for opcode in acir.opcodes { + match opcode { + Opcode::Arithmetic(arith_expr) => { + opcodes.push(Opcode::Arithmetic(GeneralOptimizer::optimize(arith_expr))) + } + other_gate => opcodes.push(other_gate), + }; + } + let acir = Circuit { opcodes, ..acir }; - let optimizer = match &np_language { + let transformer = match &np_language { crate::Language::R1CS => { - let optimizer = R1CSOptimizer::new(fallback); - return Ok(optimizer.optimize()); + let transformer = R1CSTransformer::new(acir); + return Ok(transformer.transform()); } - crate::Language::PLONKCSat { width } => CSatOptimizer::new(*width), + crate::Language::PLONKCSat { width } => CSatTransformer::new(*width), }; - // TODO: the code below is only for CSAT optimizer + // TODO: the code below is only for CSAT transformer // TODO it may be possible to refactor it in a way that we do not need to return early from the r1cs // TODO or at the very least, we could put all of it inside of CSatOptimizer pass // Optimize the arithmetic gates by reducing them into the correct width and // creating intermediate variables when necessary - let mut optimized_gates = Vec::new(); + let mut transformed_gates = Vec::new(); - let mut next_witness_index = fallback.current_witness_index + 1; - for opcode in fallback.opcodes { + let mut next_witness_index = acir.current_witness_index + 1; + for opcode in acir.opcodes { match opcode { Opcode::Arithmetic(arith_expr) => { let mut intermediate_variables: IndexMap = IndexMap::new(); - let arith_expr = - optimizer.optimize(arith_expr, &mut intermediate_variables, next_witness_index); + let arith_expr = transformer.transform( + arith_expr, + &mut intermediate_variables, + next_witness_index, + ); // Update next_witness counter next_witness_index += intermediate_variables.len() as u32; @@ -67,10 +81,10 @@ pub fn compile( new_gates.push(arith_expr); new_gates.sort(); for gate in new_gates { - optimized_gates.push(Opcode::Arithmetic(gate)); + transformed_gates.push(Opcode::Arithmetic(gate)); } } - other_gate => optimized_gates.push(other_gate), + other_gate => transformed_gates.push(other_gate), } } @@ -78,7 +92,7 @@ pub fn compile( Ok(Circuit { current_witness_index, - opcodes: optimized_gates, - public_inputs: fallback.public_inputs, // The optimizer does not add public inputs + opcodes: transformed_gates, + public_inputs: acir.public_inputs, // The optimizer does not add public inputs }) } diff --git a/acvm/src/compiler/fallback.rs b/acvm/src/compiler/fallback.rs deleted file mode 100644 index 15d1e7e41..000000000 --- a/acvm/src/compiler/fallback.rs +++ /dev/null @@ -1,97 +0,0 @@ -use super::CompileError; -use acir::{ - circuit::{opcodes::BlackBoxFuncCall, Circuit, Opcode}, - native_types::Expression, - BlackBoxFunc, -}; - -// A predicate that returns true if the black box function is supported -pub type IsBlackBoxSupported = fn(&BlackBoxFunc) -> bool; - -//ACIR pass which replace unsupported opcodes using arithmetic fallback -pub fn fallback(acir: Circuit, is_supported: IsBlackBoxSupported) -> Result { - let mut acir_supported_opcodes = Vec::with_capacity(acir.opcodes.len()); - - let mut witness_idx = acir.current_witness_index + 1; - - for opcode in acir.opcodes { - let bb_func_call = match &opcode { - Opcode::Arithmetic(_) | Opcode::Directive(_) => { - // If it is not a black box function, then it is a directive or - // an arithmetic expression which are always supported - acir_supported_opcodes.push(opcode); - continue; - } - Opcode::BlackBoxFuncCall(bb_func_call) => { - // We know it is an black box function. Now check if it is - // supported by the backend. If it is supported, then we can simply - // collect the opcode - if is_supported(&bb_func_call.name) { - acir_supported_opcodes.push(opcode); - continue; - } - bb_func_call - } - }; - - // If we get here then we know that this black box function is not supported - // so we need to replace it with a version of the opcode which only uses arithmetic - // expressions - let (updated_witness_index, opcodes_fallback) = opcode_fallback(bb_func_call, witness_idx)?; - witness_idx = updated_witness_index; - - acir_supported_opcodes.extend(opcodes_fallback); - } - - Ok(Circuit { - current_witness_index: witness_idx, - opcodes: acir_supported_opcodes, - public_inputs: acir.public_inputs, - }) -} - -fn opcode_fallback( - bb_func_call: &BlackBoxFuncCall, - current_witness_idx: u32, -) -> Result<(u32, Vec), CompileError> { - let (updated_witness_index, opcodes_fallback) = match bb_func_call.name { - BlackBoxFunc::AND => { - let (lhs, rhs, result, num_bits) = - crate::pwg::logic::extract_input_output(bb_func_call); - stdlib::fallback::and( - Expression::from(&lhs), - Expression::from(&rhs), - result, - num_bits, - current_witness_idx, - ) - } - BlackBoxFunc::XOR => { - let (lhs, rhs, result, num_bits) = - crate::pwg::logic::extract_input_output(bb_func_call); - stdlib::fallback::xor( - Expression::from(&lhs), - Expression::from(&rhs), - result, - num_bits, - current_witness_idx, - ) - } - BlackBoxFunc::RANGE => { - // TODO: add consistency checks in one place - // TODO: we aren't checking that range gate should have one input - let input = &bb_func_call.inputs[0]; - // Note there are no outputs because range produces no outputs - stdlib::fallback::range( - Expression::from(&input.witness), - input.num_bits, - current_witness_idx, - ) - } - _ => { - return Err(CompileError::UnsupportedBlackBox(bb_func_call.name)); - } - }; - - Ok((updated_witness_index, opcodes_fallback)) -} diff --git a/acvm/src/compiler/optimizer/mod.rs b/acvm/src/compiler/optimizer/mod.rs deleted file mode 100644 index e84d7cb79..000000000 --- a/acvm/src/compiler/optimizer/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod csat_optimizer; -mod general_optimizer; -mod r1cs_optimizer; - -pub use csat_optimizer::Optimizer as CSatOptimizer; -pub use general_optimizer::GeneralOpt as GeneralOptimizer; -pub use r1cs_optimizer::R1CSOptimizer; diff --git a/acvm/src/compiler/optimizer/r1cs_optimizer.rs b/acvm/src/compiler/optimizer/r1cs_optimizer.rs deleted file mode 100644 index 01ecaffa1..000000000 --- a/acvm/src/compiler/optimizer/r1cs_optimizer.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::compiler::GeneralOptimizer; -use acir::circuit::{Circuit, Opcode}; - -pub struct R1CSOptimizer { - acir: Circuit, -} - -impl R1CSOptimizer { - pub fn new(acir: Circuit) -> Self { - Self { acir } - } - // R1CS optimizations uses the general optimizer. - // TODO: We could possibly make sure that all polynomials are at most degree-2 - pub fn optimize(self) -> Circuit { - let optimized_arith_gates: Vec<_> = self - .acir - .opcodes - .into_iter() - .map(|gate| match gate { - Opcode::Arithmetic(arith) => Opcode::Arithmetic(GeneralOptimizer::optimize(arith)), - other_gates => other_gates, - }) - .collect(); - - Circuit { - // The general optimizer may remove enough gates that a witness is no longer used - // however, we cannot decrement the number of witnesses, as that - // would require a linear scan over all gates in order to decrement all witness indices - // above the witness which was removed - current_witness_index: self.acir.current_witness_index, - opcodes: optimized_arith_gates, - public_inputs: self.acir.public_inputs, - } - } -} diff --git a/acvm/src/compiler/optimizer/general_optimizer.rs b/acvm/src/compiler/optimizers/general.rs similarity index 100% rename from acvm/src/compiler/optimizer/general_optimizer.rs rename to acvm/src/compiler/optimizers/general.rs diff --git a/acvm/src/compiler/optimizers/mod.rs b/acvm/src/compiler/optimizers/mod.rs new file mode 100644 index 000000000..2b7b95f28 --- /dev/null +++ b/acvm/src/compiler/optimizers/mod.rs @@ -0,0 +1,3 @@ +mod general; + +pub use general::GeneralOpt as GeneralOptimizer; diff --git a/acvm/src/compiler/optimizer/range_optimizer.rs b/acvm/src/compiler/optimizers/range.rs similarity index 100% rename from acvm/src/compiler/optimizer/range_optimizer.rs rename to acvm/src/compiler/optimizers/range.rs diff --git a/acvm/src/compiler/optimizer/csat_optimizer.rs b/acvm/src/compiler/transformers/csat.rs similarity index 96% rename from acvm/src/compiler/optimizer/csat_optimizer.rs rename to acvm/src/compiler/transformers/csat.rs index 454bcde42..4df9365d8 100644 --- a/acvm/src/compiler/optimizer/csat_optimizer.rs +++ b/acvm/src/compiler/transformers/csat.rs @@ -6,35 +6,30 @@ use acir::{ }; use indexmap::IndexMap; -use super::general_optimizer::GeneralOpt; -// Optimizer struct with all of the related optimizations to the arithmetic gate - // Is this more of a Reducer than an optimizer? // Should we give it all of the gates? -// Have a single optimizer that you instantiate with a width, then pass many gates through -pub struct Optimizer { +// Have a single transformer that you instantiate with a width, then pass many gates through +pub struct CSatTransformer { width: usize, } -impl Optimizer { - // Configure the width for the Optimizer - pub fn new(width: usize) -> Optimizer { +impl CSatTransformer { + // Configure the width for the optimizer + pub fn new(width: usize) -> CSatTransformer { assert!(width > 2); - Optimizer { width } + CSatTransformer { width } } // Still missing dead witness optimization. // To do this, we will need the whole set of arithmetic gates // I think it can also be done before the local optimization seen here, as dead variables will come from the user - pub fn optimize( + pub fn transform( &self, gate: Expression, intermediate_variables: &mut IndexMap, num_witness: u32, ) -> Expression { - let gate = GeneralOpt::optimize(gate); - // Here we create intermediate variables and constrain them to be equal to any subset of the polynomial that can be represented as a full gate let gate = self.full_gate_scan_optimization(gate, intermediate_variables, num_witness); // The last optimization to do is to create intermediate variables in order to flatten the fan-in and the amount of mul terms @@ -348,8 +343,9 @@ fn simple_reduction_smoke_test() { let num_witness = 4; - let optimizer = Optimizer::new(3); - let got_optimized_gate_a = optimizer.optimize(gate_a, &mut intermediate_variables, num_witness); + let optimizer = CSatTransformer::new(3); + let got_optimized_gate_a = + optimizer.transform(gate_a, &mut intermediate_variables, num_witness); // a = b + c + d => a - b - c - d = 0 // For width3, the result becomes: diff --git a/acvm/src/compiler/transformers/fallback.rs b/acvm/src/compiler/transformers/fallback.rs new file mode 100644 index 000000000..3dffcf9fe --- /dev/null +++ b/acvm/src/compiler/transformers/fallback.rs @@ -0,0 +1,103 @@ +use super::super::CompileError; +use acir::{ + circuit::{opcodes::BlackBoxFuncCall, Circuit, Opcode}, + native_types::Expression, + BlackBoxFunc, +}; + +// A predicate that returns true if the black box function is supported +pub type IsBlackBoxSupported = fn(&BlackBoxFunc) -> bool; + +pub struct FallbackTransformer; + +impl FallbackTransformer { + //ACIR pass which replace unsupported opcodes using arithmetic fallback + pub fn transform( + acir: Circuit, + is_supported: IsBlackBoxSupported, + ) -> Result { + let mut acir_supported_opcodes = Vec::with_capacity(acir.opcodes.len()); + + let mut witness_idx = acir.current_witness_index + 1; + + for opcode in acir.opcodes { + let bb_func_call = match &opcode { + Opcode::Arithmetic(_) | Opcode::Directive(_) => { + // If it is not a black box function, then it is a directive or + // an arithmetic expression which are always supported + acir_supported_opcodes.push(opcode); + continue; + } + Opcode::BlackBoxFuncCall(bb_func_call) => { + // We know it is an black box function. Now check if it is + // supported by the backend. If it is supported, then we can simply + // collect the opcode + if is_supported(&bb_func_call.name) { + acir_supported_opcodes.push(opcode); + continue; + } + bb_func_call + } + }; + + // If we get here then we know that this black box function is not supported + // so we need to replace it with a version of the opcode which only uses arithmetic + // expressions + let (updated_witness_index, opcodes_fallback) = + Self::opcode_fallback(bb_func_call, witness_idx)?; + witness_idx = updated_witness_index; + + acir_supported_opcodes.extend(opcodes_fallback); + } + + Ok(Circuit { + current_witness_index: witness_idx, + opcodes: acir_supported_opcodes, + public_inputs: acir.public_inputs, + }) + } + + fn opcode_fallback( + gc: &BlackBoxFuncCall, + current_witness_idx: u32, + ) -> Result<(u32, Vec), CompileError> { + let (updated_witness_index, opcodes_fallback) = match gc.name { + BlackBoxFunc::AND => { + let (lhs, rhs, result, num_bits) = crate::pwg::logic::extract_input_output(gc); + stdlib::fallback::and( + Expression::from(&lhs), + Expression::from(&rhs), + result, + num_bits, + current_witness_idx, + ) + } + BlackBoxFunc::XOR => { + let (lhs, rhs, result, num_bits) = crate::pwg::logic::extract_input_output(gc); + stdlib::fallback::xor( + Expression::from(&lhs), + Expression::from(&rhs), + result, + num_bits, + current_witness_idx, + ) + } + BlackBoxFunc::RANGE => { + // TODO: add consistency checks in one place + // TODO: we aren't checking that range gate should have one input + let input = &gc.inputs[0]; + // Note there are no outputs because range produces no outputs + stdlib::fallback::range( + Expression::from(&input.witness), + input.num_bits, + current_witness_idx, + ) + } + _ => { + return Err(CompileError::UnsupportedBlackBox(gc.name)); + } + }; + + Ok((updated_witness_index, opcodes_fallback)) + } +} diff --git a/acvm/src/compiler/transformers/mod.rs b/acvm/src/compiler/transformers/mod.rs new file mode 100644 index 000000000..52c3737dc --- /dev/null +++ b/acvm/src/compiler/transformers/mod.rs @@ -0,0 +1,8 @@ +mod csat; +mod fallback; +mod r1cs; + +pub use csat::CSatTransformer; +pub use fallback::FallbackTransformer; +pub use fallback::IsBlackBoxSupported; +pub use r1cs::R1CSTransformer; diff --git a/acvm/src/compiler/transformers/r1cs.rs b/acvm/src/compiler/transformers/r1cs.rs new file mode 100644 index 000000000..677bbb023 --- /dev/null +++ b/acvm/src/compiler/transformers/r1cs.rs @@ -0,0 +1,15 @@ +use acir::circuit::Circuit; + +pub struct R1CSTransformer { + acir: Circuit, +} + +impl R1CSTransformer { + pub fn new(acir: Circuit) -> Self { + Self { acir } + } + // TODO: We could possibly make sure that all polynomials are at most degree-2 + pub fn transform(self) -> Circuit { + self.acir + } +} diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index 134706897..e82fa9d4b 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -214,7 +214,7 @@ pub fn hash_constraint_system(cs: &Circuit) -> [u8; 32] { // by knowing the np complete language pub fn default_is_black_box_supported( language: Language, -) -> compiler::fallback::IsBlackBoxSupported { +) -> compiler::transformers::IsBlackBoxSupported { // R1CS does not support any of the black box functions by default. // The compiler will replace those that it can -- ie range, xor, and fn r1cs_is_supported(_opcode: &BlackBoxFunc) -> bool {