Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa refactor): Speedup acir-gen #1793

Merged
merged 2 commits into from
Jun 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{errors::AcirGenError, generated_acir::GeneratedAcir};
use crate::brillig::brillig_gen::brillig_directive;
use crate::ssa_refactor::acir_gen::AcirValue;
use crate::ssa_refactor::ir::types::Type as SsaType;
use crate::ssa_refactor::ir::{instruction::Endian, map::TwoWayMap, types::NumericType};
use crate::ssa_refactor::ir::{instruction::Endian, types::NumericType};
use acvm::acir::{
brillig_vm::Opcode as BrilligOpcode,
circuit::brillig::{BrilligInputs, BrilligOutputs},
Expand All @@ -16,6 +16,7 @@ use acvm::{
FieldElement,
};
use iter_extended::vecmap;
use std::collections::HashMap;
use std::{borrow::Cow, hash::Hash};

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -92,7 +93,7 @@ pub(crate) struct AcirContext {
/// Two-way map that links `AcirVar` to `AcirVarData`.
///
/// The vars object is an instance of the `TwoWayMap`, which provides a bidirectional mapping between `AcirVar` and `AcirVarData`.
vars: TwoWayMap<AcirVar, AcirVarData>,
vars: HashMap<AcirVar, AcirVarData>,

/// An in-memory representation of ACIR.
///
Expand Down Expand Up @@ -126,7 +127,7 @@ impl AcirContext {
///
/// Note: `Variables` are immutable.
pub(crate) fn neg_var(&mut self, var: AcirVar) -> AcirVar {
let var_data = &self.vars[var];
let var_data = &self.vars[&var];
let result_data = if let AcirVarData::Const(constant) = var_data {
AcirVarData::Const(-*constant)
} else {
Expand All @@ -138,7 +139,7 @@ impl AcirContext {
/// Adds a new Variable to context whose value will
/// be constrained to be the inverse of `var`.
pub(crate) fn inv_var(&mut self, var: AcirVar) -> Result<AcirVar, AcirGenError> {
let var_data = &self.vars[var];
let var_data = &self.vars[&var];
if let AcirVarData::Const(constant) = var_data {
// Note that this will return a 0 if the inverse is not available
let result_var = self.add_data(AcirVarData::Const(constant.inverse()));
Expand Down Expand Up @@ -179,8 +180,8 @@ impl AcirContext {
/// Returns an `AcirVar` that is `1` if `lhs` equals `rhs` and
/// 0 otherwise.
pub(crate) fn eq_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
Expand Down Expand Up @@ -245,8 +246,8 @@ impl AcirContext {
/// Constrains the `lhs` and `rhs` to be equal.
pub(crate) fn assert_eq_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<(), AcirGenError> {
// TODO: could use sub_var and then assert_eq_zero
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];
if let (AcirVarData::Const(lhs_const), AcirVarData::Const(rhs_const)) = (lhs_data, rhs_data)
{
if lhs_const == rhs_const {
Expand Down Expand Up @@ -297,8 +298,8 @@ impl AcirContext {
/// Adds a new Variable to context whose value will
/// be constrained to be the multiplication of `lhs` and `rhs`
pub(crate) fn mul_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];
let result = match (lhs_data, rhs_data) {
(AcirVarData::Witness(witness), AcirVarData::Expr(expr))
| (AcirVarData::Expr(expr), AcirVarData::Witness(witness)) => {
Expand Down Expand Up @@ -351,8 +352,8 @@ impl AcirContext {
/// Adds a new Variable to context whose value will
/// be constrained to be the addition of `lhs` and `rhs`
pub(crate) fn add_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];
let result_data = if let (AcirVarData::Const(lhs_const), AcirVarData::Const(rhs_const)) =
(lhs_data, rhs_data)
{
Expand Down Expand Up @@ -388,7 +389,7 @@ impl AcirContext {
rhs: AcirVar,
_typ: AcirType,
) -> Result<AcirVar, AcirGenError> {
let rhs_data = &self.vars[rhs];
let rhs_data = &self.vars[&rhs];

// Compute 2^{rhs}
let two_pow_rhs = match rhs_data.as_constant() {
Expand All @@ -409,8 +410,8 @@ impl AcirContext {
) -> Result<(AcirVar, AcirVar), AcirGenError> {
let predicate = Expression::one();

let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
Expand Down Expand Up @@ -451,7 +452,7 @@ impl AcirContext {
rhs: AcirVar,
typ: AcirType,
) -> Result<AcirVar, AcirGenError> {
let rhs_data = &self.vars[rhs];
let rhs_data = &self.vars[&rhs];

// Compute 2^{rhs}
let two_pow_rhs = match rhs_data.as_constant() {
Expand Down Expand Up @@ -484,7 +485,7 @@ impl AcirContext {
variable: AcirVar,
numeric_type: &NumericType,
) -> Result<AcirVar, AcirGenError> {
let data = &self.vars[variable];
let data = &self.vars[&variable];
match numeric_type {
NumericType::Signed { .. } => todo!("signed integer constraining is unimplemented"),
NumericType::Unsigned { bit_size } => {
Expand All @@ -506,7 +507,7 @@ impl AcirContext {
rhs: u32,
max_bit_size: u32,
) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let lhs_data = &self.vars[&lhs];
let lhs_expr = lhs_data.to_expression();

let result_expr = self.acir_ir.truncate(&lhs_expr, rhs, max_bit_size)?;
Expand All @@ -523,8 +524,8 @@ impl AcirContext {
bit_size: u32,
predicate: Option<AcirVar>,
) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
Expand All @@ -533,7 +534,7 @@ impl AcirContext {
// TODO: The frontend should shout in this case

let predicate = predicate.map(|acir_var| {
let predicate_data = &self.vars[acir_var];
let predicate_data = &self.vars[&acir_var];
predicate_data.to_expression().into_owned()
});
let is_greater_than_eq =
Expand Down Expand Up @@ -573,7 +574,7 @@ impl AcirContext {
let domain_var =
inputs.pop().expect("ICE: Pedersen call requires domain separator").into_var();

let domain_constant = self.vars[domain_var]
let domain_constant = self.vars[&domain_var]
.as_constant()
.expect("ICE: Domain separator must be a constant");

Expand Down Expand Up @@ -606,7 +607,7 @@ impl AcirContext {
let mut witnesses = Vec::new();
for input in inputs {
for (input, typ) in input.flatten() {
let var_data = &self.vars[input];
let var_data = &self.vars[&input];

// Intrinsics only accept Witnesses. This is not a limitation of the
// intrinsics, its just how we have defined things. Ideally, we allow
Expand Down Expand Up @@ -638,12 +639,12 @@ impl AcirContext {
self.vars[&radix_var].as_constant().expect("ICE: radix should be a constant").to_u128()
as u32;

let limb_count = self.vars[limb_count_var]
let limb_count = self.vars[&limb_count_var]
.as_constant()
.expect("ICE: limb_size should be a constant")
.to_u128() as u32;

let input_expr = &self.vars[input_var].to_expression();
let input_expr = &self.vars[&input_var].to_expression();

let bit_size = u32::BITS - (radix - 1).leading_zeros();
let limbs = self.acir_ir.radix_le_decompose(input_expr, radix, limb_count, bit_size)?;
Expand Down Expand Up @@ -687,7 +688,7 @@ impl AcirContext {
let input = Self::flatten_values(input);

let witnesses = vecmap(input, |acir_var| {
let var_data = &self.vars[acir_var];
let var_data = &self.vars[&acir_var];
let expr = var_data.to_expression();
self.acir_ir.get_or_create_witness(&expr)
});
Expand Down Expand Up @@ -730,7 +731,8 @@ impl AcirContext {
/// either the key or the value.
fn add_data(&mut self, data: AcirVarData) -> AcirVar {
let id = AcirVar(self.vars.len());
self.vars.insert(id, data)
self.vars.insert(id, data);
id
}

pub(crate) fn brillig(
Expand Down Expand Up @@ -784,7 +786,7 @@ impl AcirContext {
fn brillig_array_input(&self, var_expressions: &mut Vec<Expression>, input: AcirValue) {
match input {
AcirValue::Var(var, _) => {
var_expressions.push(self.vars[var].to_expression().into_owned());
var_expressions.push(self.vars[&var].to_expression().into_owned());
}
AcirValue::Array(vars) => {
for var in vars {
Expand Down Expand Up @@ -844,18 +846,3 @@ impl AcirVarData {
/// A Reference to an `AcirVarData`
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(crate) struct AcirVar(usize);

#[test]
fn repeat_op() {
let mut ctx = AcirContext::default();

let var_a = ctx.add_variable();
let var_b = ctx.add_variable();

// Multiplying the same variables twice should yield
// the same output.
let var_c = ctx.mul_var(var_a, var_b);
let should_be_var_c = ctx.mul_var(var_a, var_b);

assert_eq!(var_c, should_be_var_c);
}