diff --git a/compiler/noirc_evaluator/src/lib.rs b/compiler/noirc_evaluator/src/lib.rs
index 8127e3d03ef..75ea557d3de 100644
--- a/compiler/noirc_evaluator/src/lib.rs
+++ b/compiler/noirc_evaluator/src/lib.rs
@@ -12,8 +12,7 @@ pub mod ssa;
pub use ssa::create_program;
pub use ssa::ir::instruction::ErrorType;
-/// Trims leading whitespace from each line of the input string, according to
-/// how much leading whitespace there is on the first non-empty line.
+/// Trims leading whitespace from each line of the input string
#[cfg(test)]
pub(crate) fn trim_leading_whitespace_from_lines(src: &str) -> String {
let mut lines = src.trim_end().lines();
@@ -21,11 +20,10 @@ pub(crate) fn trim_leading_whitespace_from_lines(src: &str) -> String {
while first_line.is_empty() {
first_line = lines.next().unwrap();
}
- let indent = first_line.len() - first_line.trim_start().len();
let mut result = first_line.trim_start().to_string();
for line in lines {
result.push('\n');
- result.push_str(&line[indent..]);
+ result.push_str(line.trim_start());
}
result
}
diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
index 301b75e0bd4..db085bd762f 100644
--- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
+++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
@@ -2,10 +2,11 @@ use std::sync::Arc;
use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement};
+use crate::ssa::ir::instruction::BlackBoxFunc;
use crate::ssa::ir::{
basic_block::BasicBlockId,
dfg::{CallStack, DataFlowGraph},
- instruction::{Instruction, SimplifyResult},
+ instruction::{Instruction, Intrinsic, SimplifyResult},
types::Type,
value::ValueId,
};
@@ -70,52 +71,125 @@ pub(super) fn simplify_msm(
block: BasicBlockId,
call_stack: &CallStack,
) -> SimplifyResult {
- // TODO: Handle MSMs where a subset of the terms are constant.
+ let mut is_constant;
+
match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) {
(Some((points, _)), Some((scalars, _))) => {
- let Some(points) = points
- .into_iter()
- .map(|id| dfg.get_numeric_constant(id))
- .collect::>>()
- else {
- return SimplifyResult::None;
- };
-
- let Some(scalars) = scalars
- .into_iter()
- .map(|id| dfg.get_numeric_constant(id))
- .collect:: >>()
- else {
- return SimplifyResult::None;
- };
+ // We decompose points and scalars into constant and non-constant parts in order to simplify MSMs where a subset of the terms are constant.
+ let mut constant_points = vec![];
+ let mut constant_scalars_lo = vec![];
+ let mut constant_scalars_hi = vec![];
+ let mut var_points = vec![];
+ let mut var_scalars = vec![];
+ let len = scalars.len() / 2;
+ for i in 0..len {
+ match (
+ dfg.get_numeric_constant(scalars[2 * i]),
+ dfg.get_numeric_constant(scalars[2 * i + 1]),
+ dfg.get_numeric_constant(points[3 * i]),
+ dfg.get_numeric_constant(points[3 * i + 1]),
+ dfg.get_numeric_constant(points[3 * i + 2]),
+ ) {
+ (Some(lo), Some(hi), _, _, _) if lo.is_zero() && hi.is_zero() => {
+ is_constant = true;
+ constant_scalars_lo.push(lo);
+ constant_scalars_hi.push(hi);
+ constant_points.push(FieldElement::zero());
+ constant_points.push(FieldElement::zero());
+ constant_points.push(FieldElement::one());
+ }
+ (_, _, _, _, Some(infinity)) if infinity.is_one() => {
+ is_constant = true;
+ constant_scalars_lo.push(FieldElement::zero());
+ constant_scalars_hi.push(FieldElement::zero());
+ constant_points.push(FieldElement::zero());
+ constant_points.push(FieldElement::zero());
+ constant_points.push(FieldElement::one());
+ }
+ (Some(lo), Some(hi), Some(x), Some(y), Some(infinity)) => {
+ is_constant = true;
+ constant_scalars_lo.push(lo);
+ constant_scalars_hi.push(hi);
+ constant_points.push(x);
+ constant_points.push(y);
+ constant_points.push(infinity);
+ }
+ _ => {
+ is_constant = false;
+ }
+ }
- let mut scalars_lo = Vec::new();
- let mut scalars_hi = Vec::new();
- for (i, scalar) in scalars.into_iter().enumerate() {
- if i % 2 == 0 {
- scalars_lo.push(scalar);
- } else {
- scalars_hi.push(scalar);
+ if !is_constant {
+ var_points.push(points[3 * i]);
+ var_points.push(points[3 * i + 1]);
+ var_points.push(points[3 * i + 2]);
+ var_scalars.push(scalars[2 * i]);
+ var_scalars.push(scalars[2 * i + 1]);
}
}
- let Ok((result_x, result_y, result_is_infinity)) =
- solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)
- else {
+ // If there are no constant terms, we can't simplify
+ if constant_scalars_lo.is_empty() {
+ return SimplifyResult::None;
+ }
+ let Ok((result_x, result_y, result_is_infinity)) = solver.multi_scalar_mul(
+ &constant_points,
+ &constant_scalars_lo,
+ &constant_scalars_hi,
+ ) else {
return SimplifyResult::None;
};
- let result_x = dfg.make_constant(result_x, Type::field());
- let result_y = dfg.make_constant(result_y, Type::field());
- let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());
-
- let elements = im::vector![result_x, result_y, result_is_infinity];
- let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
- let instruction = Instruction::MakeArray { elements, typ };
- let result_array =
- dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone());
-
- SimplifyResult::SimplifiedTo(result_array.first())
+ // If there are no variable term, we can directly return the constant result
+ if var_scalars.is_empty() {
+ let result_x = dfg.make_constant(result_x, Type::field());
+ let result_y = dfg.make_constant(result_y, Type::field());
+ let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());
+
+ let elements = im::vector![result_x, result_y, result_is_infinity];
+ let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
+ let instruction = Instruction::MakeArray { elements, typ };
+ let result_array = dfg.insert_instruction_and_results(
+ instruction,
+ block,
+ None,
+ call_stack.clone(),
+ );
+
+ return SimplifyResult::SimplifiedTo(result_array.first());
+ }
+ // If there is only one non-null constant term, we cannot simplify
+ if constant_scalars_lo.len() == 1 && result_is_infinity != FieldElement::one() {
+ return SimplifyResult::None;
+ }
+ // Add the constant part back to the non-constant part, if it is not null
+ let one = dfg.make_constant(FieldElement::one(), Type::field());
+ let zero = dfg.make_constant(FieldElement::zero(), Type::field());
+ if result_is_infinity.is_zero() {
+ var_scalars.push(one);
+ var_scalars.push(zero);
+ let result_x = dfg.make_constant(result_x, Type::field());
+ let result_y = dfg.make_constant(result_y, Type::field());
+ let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());
+ var_points.push(result_x);
+ var_points.push(result_y);
+ var_points.push(result_is_infinity);
+ }
+ // Construct the simplified MSM expression
+ let typ = Type::Array(Arc::new(vec![Type::field()]), var_scalars.len() as u32);
+ let scalars = Instruction::MakeArray { elements: var_scalars.into(), typ };
+ let scalars = dfg
+ .insert_instruction_and_results(scalars, block, None, call_stack.clone())
+ .first();
+ let typ = Type::Array(Arc::new(vec![Type::field()]), var_points.len() as u32);
+ let points = Instruction::MakeArray { elements: var_points.into(), typ };
+ let points =
+ dfg.insert_instruction_and_results(points, block, None, call_stack.clone()).first();
+ let msm = dfg.import_intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul));
+ SimplifyResult::SimplifiedToInstruction(Instruction::Call {
+ func: msm,
+ arguments: vec![points, scalars],
+ })
}
_ => SimplifyResult::None,
}
@@ -261,3 +335,93 @@ pub(super) fn simplify_signature(
_ => SimplifyResult::None,
}
}
+
+#[cfg(feature = "bn254")]
+#[cfg(test)]
+mod test {
+ use crate::ssa::opt::assert_normalized_ssa_equals;
+ use crate::ssa::Ssa;
+
+ #[cfg(feature = "bn254")]
+ #[test]
+ fn full_constant_folding() {
+ let src = r#"
+ acir(inline) fn main f0 {
+ b0():
+ v0 = make_array [Field 2, Field 3, Field 5, Field 5] : [Field; 4]
+ v1 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 6]
+ v2 = call multi_scalar_mul (v1, v0) -> [Field; 3]
+ return v2
+ }"#;
+ let ssa = Ssa::from_str(src).unwrap();
+
+ let expected_src = r#"
+ acir(inline) fn main f0 {
+ b0():
+ v3 = make_array [Field 2, Field 3, Field 5, Field 5] : [Field; 4]
+ v7 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 6]
+ v10 = make_array [Field 1478523918288173385110236399861791147958001875200066088686689589556927843200, Field 700144278551281040379388961242974992655630750193306467120985766322057145630, Field 0] : [Field; 3]
+ return v10
+ }
+ "#;
+ assert_normalized_ssa_equals(ssa, expected_src);
+ }
+
+ #[cfg(feature = "bn254")]
+ #[test]
+ fn simplify_zero() {
+ let src = r#"
+ acir(inline) fn main f0 {
+ b0(v0: Field, v1: Field):
+ v2 = make_array [v0, Field 0, Field 0, Field 0, v0, Field 0] : [Field; 6]
+ v3 = make_array [
+ Field 0, Field 0, Field 1, v0, v1, Field 0, Field 1, v0, Field 0] : [Field; 9]
+ v4 = call multi_scalar_mul (v3, v2) -> [Field; 3]
+
+ return v4
+
+ }"#;
+ let ssa = Ssa::from_str(src).unwrap();
+ //First point is zero, second scalar is zero, so we should be left with the scalar mul of the last point.
+ let expected_src = r#"
+ acir(inline) fn main f0 {
+ b0(v0: Field, v1: Field):
+ v3 = make_array [v0, Field 0, Field 0, Field 0, v0, Field 0] : [Field; 6]
+ v5 = make_array [Field 0, Field 0, Field 1, v0, v1, Field 0, Field 1, v0, Field 0] : [Field; 9]
+ v6 = make_array [v0, Field 0] : [Field; 2]
+ v7 = make_array [Field 1, v0, Field 0] : [Field; 3]
+ v9 = call multi_scalar_mul(v7, v6) -> [Field; 3]
+ return v9
+ }
+ "#;
+ assert_normalized_ssa_equals(ssa, expected_src);
+ }
+
+ #[cfg(feature = "bn254")]
+ #[test]
+ fn partial_constant_folding() {
+ let src = r#"
+ acir(inline) fn main f0 {
+ b0(v0: Field, v1: Field):
+ v2 = make_array [Field 1, Field 0, v0, Field 0, Field 2, Field 0] : [Field; 6]
+ v3 = make_array [
+ Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, v0, v1, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 9]
+ v4 = call multi_scalar_mul (v3, v2) -> [Field; 3]
+ return v4
+ }"#;
+ let ssa = Ssa::from_str(src).unwrap();
+ //First and last scalar/point are constant, so we should be left with the msm of the middle point and the folded constant point
+ let expected_src = r#"
+ acir(inline) fn main f0 {
+ b0(v0: Field, v1: Field):
+ v5 = make_array [Field 1, Field 0, v0, Field 0, Field 2, Field 0] : [Field; 6]
+ v7 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, v0, v1, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 9]
+ v8 = make_array [v0, Field 0, Field 1, Field 0] : [Field; 4]
+ v12 = make_array [v0, v1, Field 0, Field -3227352362257037263902424173275354266044964400219754872043023745437788450996, Field 8902249110305491597038405103722863701255802573786510474664632793109847672620, u1 0] : [Field; 6]
+ v14 = call multi_scalar_mul(v12, v8) -> [Field; 3]
+ return v14
+ }
+ "#;
+ assert_normalized_ssa_equals(ssa, expected_src);
+ }
+}
diff --git a/test_programs/execution_success/embedded_curve_ops/src/main.nr b/test_programs/execution_success/embedded_curve_ops/src/main.nr
index e69184b9c96..85cf60dc796 100644
--- a/test_programs/execution_success/embedded_curve_ops/src/main.nr
+++ b/test_programs/execution_success/embedded_curve_ops/src/main.nr
@@ -20,4 +20,22 @@ fn main(priv_key: Field, pub_x: pub Field, pub_y: pub Field) {
// The results should be double the g1 point because the scalars are 1 and we pass in g1 twice
assert(double.x == res.x);
+
+ // Tests for #6549
+ let const_scalar1 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 23, hi: 0 };
+ let const_scalar2 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 23 };
+ let const_scalar3 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 13, hi: 4 };
+ let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
+ [g1, double, pub_point, g1, g1],
+ [scalar, const_scalar1, scalar, const_scalar2, const_scalar3],
+ );
+ assert(partial_mul.x == 0x2024c4eebfbc8a20018f8c95c7aab77c6f34f10cf785a6f04e97452d8708fda7);
+ // Check simplification by zero
+ let zero_point = std::embedded_curve_ops::EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true };
+ let const_zero = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 0 };
+ let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
+ [zero_point, double, g1],
+ [scalar, const_zero, scalar],
+ );
+ assert(partial_mul == g1);
}