Skip to content

Commit

Permalink
simplify MSM with constant folding
Browse files Browse the repository at this point in the history
  • Loading branch information
guipublic committed Nov 29, 2024
1 parent aa143a7 commit 0e2afe9
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 38 deletions.
152 changes: 114 additions & 38 deletions compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use std::sync::Arc;

use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement};

use crate::ssa::ir::instruction::BlackBoxFunc;
use crate::ssa::ir::{
basic_block::BasicBlockId,
dfg::{CallStack, DataFlowGraph},
instruction::{Instruction, SimplifyResult},
instruction::{Instruction, Intrinsic, SimplifyResult},
types::Type,
value::ValueId,
};
Expand Down Expand Up @@ -70,52 +71,127 @@ pub(super) fn simplify_msm(
block: BasicBlockId,
call_stack: &CallStack,
) -> SimplifyResult {
// TODO: Handle MSMs where a subset of the terms are constant.
let mut is_constant;

match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) {
(Some((points, _)), Some((scalars, _))) => {
let Some(points) = points
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};

let Some(scalars) = scalars
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};
// We decompose points and scalars into constant and non-constant parts in order to simplify MSMs where a subset of the terms are constant.
let mut constant_points = vec![];
let mut constant_scalars_lo = vec![];
let mut constant_scalars_hi = vec![];
let mut var_points = vec![];
let mut var_scalars = vec![];
let len = scalars.len() / 2;
for i in 0..len {
match (
dfg.get_numeric_constant(scalars[2 * i]),
dfg.get_numeric_constant(scalars[2 * i + 1]),
dfg.get_numeric_constant(points[3 * i]),
dfg.get_numeric_constant(points[3 * i + 1]),
dfg.get_numeric_constant(points[3 * i + 2]),
) {
(Some(lo), Some(hi), _, _, _)
if lo == FieldElement::zero() && hi == FieldElement::zero() =>
{
is_constant = true;
constant_scalars_lo.push(lo);
constant_scalars_hi.push(hi);
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::one());
}
(_, _, _, _, Some(infinity)) if infinity == FieldElement::one() => {
is_constant = true;
constant_scalars_lo.push(FieldElement::zero());
constant_scalars_hi.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::one());
}
(Some(lo), Some(hi), Some(x), Some(y), Some(infinity)) => {
is_constant = true;
constant_scalars_lo.push(lo);
constant_scalars_hi.push(hi);
constant_points.push(x);
constant_points.push(y);
constant_points.push(infinity);
}
_ => {
is_constant = false;
}
}

let mut scalars_lo = Vec::new();
let mut scalars_hi = Vec::new();
for (i, scalar) in scalars.into_iter().enumerate() {
if i % 2 == 0 {
scalars_lo.push(scalar);
} else {
scalars_hi.push(scalar);
if !is_constant {
var_points.push(points[3 * i]);
var_points.push(points[3 * i + 1]);
var_points.push(points[3 * i + 2]);
var_scalars.push(scalars[2 * i]);
var_scalars.push(scalars[2 * i + 1]);
}
}

let Ok((result_x, result_y, result_is_infinity)) =
solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)
else {
// If there are no constant terms, we can't simplify
if constant_scalars_lo.is_empty() {
return SimplifyResult::None;
}
let Ok((result_x, result_y, result_is_infinity)) = solver.multi_scalar_mul(
&constant_points,
&constant_scalars_lo,
&constant_scalars_hi,
) else {
return SimplifyResult::None;
};

let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array =
dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone());

SimplifyResult::SimplifiedTo(result_array.first())
// If there are no variable term, we can directly return the constant result
if var_scalars.is_empty() {
let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array = dfg.insert_instruction_and_results(
instruction,
block,
None,
call_stack.clone(),
);

return SimplifyResult::SimplifiedTo(result_array.first());
}
// If there is only one non-null constant term, we cannot simplify
if constant_scalars_lo.len() == 1 && result_is_infinity != FieldElement::one() {
return SimplifyResult::None;
}
// Add the constant part back to the non-constant part, if it is not null
if result_is_infinity != FieldElement::one() {
let one = dfg.make_constant(FieldElement::one(), Type::field());
let zero = dfg.make_constant(FieldElement::zero(), Type::field());
var_scalars.push(one);
var_scalars.push(zero);
let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());
var_points.push(result_x);
var_points.push(result_y);
var_points.push(result_is_infinity);
}
// Construct the simplified MSM expression
let typ = Type::Array(Arc::new(vec![Type::field()]), var_scalars.len());
let scalars = Instruction::MakeArray { elements: var_scalars.into(), typ };
let scalars = dfg
.insert_instruction_and_results(scalars, block, None, call_stack.clone())
.first();
let typ = Type::Array(Arc::new(vec![Type::field()]), var_points.len());
let points = Instruction::MakeArray { elements: var_points.into(), typ };
let points =
dfg.insert_instruction_and_results(points, block, None, call_stack.clone()).first();
let msm = dfg.import_intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul));
SimplifyResult::SimplifiedToInstruction(Instruction::Call {
func: msm,
arguments: vec![points, scalars],
})
}
_ => SimplifyResult::None,
}
Expand Down
18 changes: 18 additions & 0 deletions test_programs/execution_success/embedded_curve_ops/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,22 @@ fn main(priv_key: Field, pub_x: pub Field, pub_y: pub Field) {

// The results should be double the g1 point because the scalars are 1 and we pass in g1 twice
assert(double.x == res.x);

// Tests for #6549
let const_scalar1 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 23, hi: 0 };
let const_scalar2 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 23 };
let const_scalar3 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 13, hi: 4 };
let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
[g1, double, pub_point, g1, g1],
[scalar, const_scalar1, scalar, const_scalar2, const_scalar3],
);
assert(partial_mul.x == 0x2024c4eebfbc8a20018f8c95c7aab77c6f34f10cf785a6f04e97452d8708fda7);
// Check simplification by zero
let zero_point = std::embedded_curve_ops::EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true };
let const_zero = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 0 };
let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
[zero_point, double, g1],
[scalar, const_zero, scalar],
);
assert(partial_mul == g1);
}

0 comments on commit 0e2afe9

Please sign in to comment.