From 6d0f86ba389a5b59b1d7fdcadcbce3e40eecaa48 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Fri, 6 Dec 2024 20:14:57 +0100 Subject: [PATCH] chore: simplify MSM with constant folding (#6650) --- compiler/noirc_evaluator/src/lib.rs | 6 +- .../src/ssa/ir/instruction/call/blackbox.rs | 240 +++++++++++++++--- .../embedded_curve_ops/src/main.nr | 18 ++ 3 files changed, 222 insertions(+), 42 deletions(-) diff --git a/compiler/noirc_evaluator/src/lib.rs b/compiler/noirc_evaluator/src/lib.rs index 8127e3d03ef..75ea557d3de 100644 --- a/compiler/noirc_evaluator/src/lib.rs +++ b/compiler/noirc_evaluator/src/lib.rs @@ -12,8 +12,7 @@ pub mod ssa; pub use ssa::create_program; pub use ssa::ir::instruction::ErrorType; -/// Trims leading whitespace from each line of the input string, according to -/// how much leading whitespace there is on the first non-empty line. +/// Trims leading whitespace from each line of the input string #[cfg(test)] pub(crate) fn trim_leading_whitespace_from_lines(src: &str) -> String { let mut lines = src.trim_end().lines(); @@ -21,11 +20,10 @@ pub(crate) fn trim_leading_whitespace_from_lines(src: &str) -> String { while first_line.is_empty() { first_line = lines.next().unwrap(); } - let indent = first_line.len() - first_line.trim_start().len(); let mut result = first_line.trim_start().to_string(); for line in lines { result.push('\n'); - result.push_str(&line[indent..]); + result.push_str(line.trim_start()); } result } diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs index 301b75e0bd4..db085bd762f 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs @@ -2,10 +2,11 @@ use std::sync::Arc; use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement}; +use crate::ssa::ir::instruction::BlackBoxFunc; use crate::ssa::ir::{ basic_block::BasicBlockId, dfg::{CallStack, DataFlowGraph}, - instruction::{Instruction, SimplifyResult}, + instruction::{Instruction, Intrinsic, SimplifyResult}, types::Type, value::ValueId, }; @@ -70,52 +71,125 @@ pub(super) fn simplify_msm( block: BasicBlockId, call_stack: &CallStack, ) -> SimplifyResult { - // TODO: Handle MSMs where a subset of the terms are constant. + let mut is_constant; + match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) { (Some((points, _)), Some((scalars, _))) => { - let Some(points) = points - .into_iter() - .map(|id| dfg.get_numeric_constant(id)) - .collect::>>() - else { - return SimplifyResult::None; - }; - - let Some(scalars) = scalars - .into_iter() - .map(|id| dfg.get_numeric_constant(id)) - .collect::>>() - else { - return SimplifyResult::None; - }; + // We decompose points and scalars into constant and non-constant parts in order to simplify MSMs where a subset of the terms are constant. + let mut constant_points = vec![]; + let mut constant_scalars_lo = vec![]; + let mut constant_scalars_hi = vec![]; + let mut var_points = vec![]; + let mut var_scalars = vec![]; + let len = scalars.len() / 2; + for i in 0..len { + match ( + dfg.get_numeric_constant(scalars[2 * i]), + dfg.get_numeric_constant(scalars[2 * i + 1]), + dfg.get_numeric_constant(points[3 * i]), + dfg.get_numeric_constant(points[3 * i + 1]), + dfg.get_numeric_constant(points[3 * i + 2]), + ) { + (Some(lo), Some(hi), _, _, _) if lo.is_zero() && hi.is_zero() => { + is_constant = true; + constant_scalars_lo.push(lo); + constant_scalars_hi.push(hi); + constant_points.push(FieldElement::zero()); + constant_points.push(FieldElement::zero()); + constant_points.push(FieldElement::one()); + } + (_, _, _, _, Some(infinity)) if infinity.is_one() => { + is_constant = true; + constant_scalars_lo.push(FieldElement::zero()); + constant_scalars_hi.push(FieldElement::zero()); + constant_points.push(FieldElement::zero()); + constant_points.push(FieldElement::zero()); + constant_points.push(FieldElement::one()); + } + (Some(lo), Some(hi), Some(x), Some(y), Some(infinity)) => { + is_constant = true; + constant_scalars_lo.push(lo); + constant_scalars_hi.push(hi); + constant_points.push(x); + constant_points.push(y); + constant_points.push(infinity); + } + _ => { + is_constant = false; + } + } - let mut scalars_lo = Vec::new(); - let mut scalars_hi = Vec::new(); - for (i, scalar) in scalars.into_iter().enumerate() { - if i % 2 == 0 { - scalars_lo.push(scalar); - } else { - scalars_hi.push(scalar); + if !is_constant { + var_points.push(points[3 * i]); + var_points.push(points[3 * i + 1]); + var_points.push(points[3 * i + 2]); + var_scalars.push(scalars[2 * i]); + var_scalars.push(scalars[2 * i + 1]); } } - let Ok((result_x, result_y, result_is_infinity)) = - solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi) - else { + // If there are no constant terms, we can't simplify + if constant_scalars_lo.is_empty() { + return SimplifyResult::None; + } + let Ok((result_x, result_y, result_is_infinity)) = solver.multi_scalar_mul( + &constant_points, + &constant_scalars_lo, + &constant_scalars_hi, + ) else { return SimplifyResult::None; }; - let result_x = dfg.make_constant(result_x, Type::field()); - let result_y = dfg.make_constant(result_y, Type::field()); - let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field()); - - let elements = im::vector![result_x, result_y, result_is_infinity]; - let typ = Type::Array(Arc::new(vec![Type::field()]), 3); - let instruction = Instruction::MakeArray { elements, typ }; - let result_array = - dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone()); - - SimplifyResult::SimplifiedTo(result_array.first()) + // If there are no variable term, we can directly return the constant result + if var_scalars.is_empty() { + let result_x = dfg.make_constant(result_x, Type::field()); + let result_y = dfg.make_constant(result_y, Type::field()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field()); + + let elements = im::vector![result_x, result_y, result_is_infinity]; + let typ = Type::Array(Arc::new(vec![Type::field()]), 3); + let instruction = Instruction::MakeArray { elements, typ }; + let result_array = dfg.insert_instruction_and_results( + instruction, + block, + None, + call_stack.clone(), + ); + + return SimplifyResult::SimplifiedTo(result_array.first()); + } + // If there is only one non-null constant term, we cannot simplify + if constant_scalars_lo.len() == 1 && result_is_infinity != FieldElement::one() { + return SimplifyResult::None; + } + // Add the constant part back to the non-constant part, if it is not null + let one = dfg.make_constant(FieldElement::one(), Type::field()); + let zero = dfg.make_constant(FieldElement::zero(), Type::field()); + if result_is_infinity.is_zero() { + var_scalars.push(one); + var_scalars.push(zero); + let result_x = dfg.make_constant(result_x, Type::field()); + let result_y = dfg.make_constant(result_y, Type::field()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); + var_points.push(result_x); + var_points.push(result_y); + var_points.push(result_is_infinity); + } + // Construct the simplified MSM expression + let typ = Type::Array(Arc::new(vec![Type::field()]), var_scalars.len() as u32); + let scalars = Instruction::MakeArray { elements: var_scalars.into(), typ }; + let scalars = dfg + .insert_instruction_and_results(scalars, block, None, call_stack.clone()) + .first(); + let typ = Type::Array(Arc::new(vec![Type::field()]), var_points.len() as u32); + let points = Instruction::MakeArray { elements: var_points.into(), typ }; + let points = + dfg.insert_instruction_and_results(points, block, None, call_stack.clone()).first(); + let msm = dfg.import_intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul)); + SimplifyResult::SimplifiedToInstruction(Instruction::Call { + func: msm, + arguments: vec![points, scalars], + }) } _ => SimplifyResult::None, } @@ -261,3 +335,93 @@ pub(super) fn simplify_signature( _ => SimplifyResult::None, } } + +#[cfg(feature = "bn254")] +#[cfg(test)] +mod test { + use crate::ssa::opt::assert_normalized_ssa_equals; + use crate::ssa::Ssa; + + #[cfg(feature = "bn254")] + #[test] + fn full_constant_folding() { + let src = r#" + acir(inline) fn main f0 { + b0(): + v0 = make_array [Field 2, Field 3, Field 5, Field 5] : [Field; 4] + v1 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 6] + v2 = call multi_scalar_mul (v1, v0) -> [Field; 3] + return v2 + }"#; + let ssa = Ssa::from_str(src).unwrap(); + + let expected_src = r#" + acir(inline) fn main f0 { + b0(): + v3 = make_array [Field 2, Field 3, Field 5, Field 5] : [Field; 4] + v7 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 6] + v10 = make_array [Field 1478523918288173385110236399861791147958001875200066088686689589556927843200, Field 700144278551281040379388961242974992655630750193306467120985766322057145630, Field 0] : [Field; 3] + return v10 + } + "#; + assert_normalized_ssa_equals(ssa, expected_src); + } + + #[cfg(feature = "bn254")] + #[test] + fn simplify_zero() { + let src = r#" + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v2 = make_array [v0, Field 0, Field 0, Field 0, v0, Field 0] : [Field; 6] + v3 = make_array [ + Field 0, Field 0, Field 1, v0, v1, Field 0, Field 1, v0, Field 0] : [Field; 9] + v4 = call multi_scalar_mul (v3, v2) -> [Field; 3] + + return v4 + + }"#; + let ssa = Ssa::from_str(src).unwrap(); + //First point is zero, second scalar is zero, so we should be left with the scalar mul of the last point. + let expected_src = r#" + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = make_array [v0, Field 0, Field 0, Field 0, v0, Field 0] : [Field; 6] + v5 = make_array [Field 0, Field 0, Field 1, v0, v1, Field 0, Field 1, v0, Field 0] : [Field; 9] + v6 = make_array [v0, Field 0] : [Field; 2] + v7 = make_array [Field 1, v0, Field 0] : [Field; 3] + v9 = call multi_scalar_mul(v7, v6) -> [Field; 3] + return v9 + } + "#; + assert_normalized_ssa_equals(ssa, expected_src); + } + + #[cfg(feature = "bn254")] + #[test] + fn partial_constant_folding() { + let src = r#" + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v2 = make_array [Field 1, Field 0, v0, Field 0, Field 2, Field 0] : [Field; 6] + v3 = make_array [ + Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, v0, v1, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 9] + v4 = call multi_scalar_mul (v3, v2) -> [Field; 3] + return v4 + }"#; + let ssa = Ssa::from_str(src).unwrap(); + //First and last scalar/point are constant, so we should be left with the msm of the middle point and the folded constant point + let expected_src = r#" + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v5 = make_array [Field 1, Field 0, v0, Field 0, Field 2, Field 0] : [Field; 6] + v7 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, v0, v1, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 9] + v8 = make_array [v0, Field 0, Field 1, Field 0] : [Field; 4] + v12 = make_array [v0, v1, Field 0, Field -3227352362257037263902424173275354266044964400219754872043023745437788450996, Field 8902249110305491597038405103722863701255802573786510474664632793109847672620, u1 0] : [Field; 6] + v14 = call multi_scalar_mul(v12, v8) -> [Field; 3] + return v14 + } + "#; + assert_normalized_ssa_equals(ssa, expected_src); + } +} diff --git a/test_programs/execution_success/embedded_curve_ops/src/main.nr b/test_programs/execution_success/embedded_curve_ops/src/main.nr index e69184b9c96..85cf60dc796 100644 --- a/test_programs/execution_success/embedded_curve_ops/src/main.nr +++ b/test_programs/execution_success/embedded_curve_ops/src/main.nr @@ -20,4 +20,22 @@ fn main(priv_key: Field, pub_x: pub Field, pub_y: pub Field) { // The results should be double the g1 point because the scalars are 1 and we pass in g1 twice assert(double.x == res.x); + + // Tests for #6549 + let const_scalar1 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 23, hi: 0 }; + let const_scalar2 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 23 }; + let const_scalar3 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 13, hi: 4 }; + let partial_mul = std::embedded_curve_ops::multi_scalar_mul( + [g1, double, pub_point, g1, g1], + [scalar, const_scalar1, scalar, const_scalar2, const_scalar3], + ); + assert(partial_mul.x == 0x2024c4eebfbc8a20018f8c95c7aab77c6f34f10cf785a6f04e97452d8708fda7); + // Check simplification by zero + let zero_point = std::embedded_curve_ops::EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true }; + let const_zero = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 0 }; + let partial_mul = std::embedded_curve_ops::multi_scalar_mul( + [zero_point, double, g1], + [scalar, const_zero, scalar], + ); + assert(partial_mul == g1); }