From 0e2afe94685ef5f407fdc5f2316849932f152fb6 Mon Sep 17 00:00:00 2001 From: guipublic Date: Fri, 29 Nov 2024 11:43:13 +0000 Subject: [PATCH] simplify MSM with constant folding --- .../src/ssa/ir/instruction/call/blackbox.rs | 152 +++++++++++++----- .../embedded_curve_ops/src/main.nr | 18 +++ 2 files changed, 132 insertions(+), 38 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs index 4f2a31e2fb0..aacd408b44a 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs @@ -2,10 +2,11 @@ use std::sync::Arc; use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement}; +use crate::ssa::ir::instruction::BlackBoxFunc; use crate::ssa::ir::{ basic_block::BasicBlockId, dfg::{CallStack, DataFlowGraph}, - instruction::{Instruction, SimplifyResult}, + instruction::{Instruction, Intrinsic, SimplifyResult}, types::Type, value::ValueId, }; @@ -70,52 +71,127 @@ pub(super) fn simplify_msm( block: BasicBlockId, call_stack: &CallStack, ) -> SimplifyResult { - // TODO: Handle MSMs where a subset of the terms are constant. + let mut is_constant; + match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) { (Some((points, _)), Some((scalars, _))) => { - let Some(points) = points - .into_iter() - .map(|id| dfg.get_numeric_constant(id)) - .collect::>>() - else { - return SimplifyResult::None; - }; - - let Some(scalars) = scalars - .into_iter() - .map(|id| dfg.get_numeric_constant(id)) - .collect::>>() - else { - return SimplifyResult::None; - }; + // We decompose points and scalars into constant and non-constant parts in order to simplify MSMs where a subset of the terms are constant. + let mut constant_points = vec![]; + let mut constant_scalars_lo = vec![]; + let mut constant_scalars_hi = vec![]; + let mut var_points = vec![]; + let mut var_scalars = vec![]; + let len = scalars.len() / 2; + for i in 0..len { + match ( + dfg.get_numeric_constant(scalars[2 * i]), + dfg.get_numeric_constant(scalars[2 * i + 1]), + dfg.get_numeric_constant(points[3 * i]), + dfg.get_numeric_constant(points[3 * i + 1]), + dfg.get_numeric_constant(points[3 * i + 2]), + ) { + (Some(lo), Some(hi), _, _, _) + if lo == FieldElement::zero() && hi == FieldElement::zero() => + { + is_constant = true; + constant_scalars_lo.push(lo); + constant_scalars_hi.push(hi); + constant_points.push(FieldElement::zero()); + constant_points.push(FieldElement::zero()); + constant_points.push(FieldElement::one()); + } + (_, _, _, _, Some(infinity)) if infinity == FieldElement::one() => { + is_constant = true; + constant_scalars_lo.push(FieldElement::zero()); + constant_scalars_hi.push(FieldElement::zero()); + constant_points.push(FieldElement::zero()); + constant_points.push(FieldElement::zero()); + constant_points.push(FieldElement::one()); + } + (Some(lo), Some(hi), Some(x), Some(y), Some(infinity)) => { + is_constant = true; + constant_scalars_lo.push(lo); + constant_scalars_hi.push(hi); + constant_points.push(x); + constant_points.push(y); + constant_points.push(infinity); + } + _ => { + is_constant = false; + } + } - let mut scalars_lo = Vec::new(); - let mut scalars_hi = Vec::new(); - for (i, scalar) in scalars.into_iter().enumerate() { - if i % 2 == 0 { - scalars_lo.push(scalar); - } else { - scalars_hi.push(scalar); + if !is_constant { + var_points.push(points[3 * i]); + var_points.push(points[3 * i + 1]); + var_points.push(points[3 * i + 2]); + var_scalars.push(scalars[2 * i]); + var_scalars.push(scalars[2 * i + 1]); } } - let Ok((result_x, result_y, result_is_infinity)) = - solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi) - else { + // If there are no constant terms, we can't simplify + if constant_scalars_lo.is_empty() { + return SimplifyResult::None; + } + let Ok((result_x, result_y, result_is_infinity)) = solver.multi_scalar_mul( + &constant_points, + &constant_scalars_lo, + &constant_scalars_hi, + ) else { return SimplifyResult::None; }; - let result_x = dfg.make_constant(result_x, Type::field()); - let result_y = dfg.make_constant(result_y, Type::field()); - let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); - - let elements = im::vector![result_x, result_y, result_is_infinity]; - let typ = Type::Array(Arc::new(vec![Type::field()]), 3); - let instruction = Instruction::MakeArray { elements, typ }; - let result_array = - dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone()); - - SimplifyResult::SimplifiedTo(result_array.first()) + // If there are no variable term, we can directly return the constant result + if var_scalars.is_empty() { + let result_x = dfg.make_constant(result_x, Type::field()); + let result_y = dfg.make_constant(result_y, Type::field()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); + + let elements = im::vector![result_x, result_y, result_is_infinity]; + let typ = Type::Array(Arc::new(vec![Type::field()]), 3); + let instruction = Instruction::MakeArray { elements, typ }; + let result_array = dfg.insert_instruction_and_results( + instruction, + block, + None, + call_stack.clone(), + ); + + return SimplifyResult::SimplifiedTo(result_array.first()); + } + // If there is only one non-null constant term, we cannot simplify + if constant_scalars_lo.len() == 1 && result_is_infinity != FieldElement::one() { + return SimplifyResult::None; + } + // Add the constant part back to the non-constant part, if it is not null + if result_is_infinity != FieldElement::one() { + let one = dfg.make_constant(FieldElement::one(), Type::field()); + let zero = dfg.make_constant(FieldElement::zero(), Type::field()); + var_scalars.push(one); + var_scalars.push(zero); + let result_x = dfg.make_constant(result_x, Type::field()); + let result_y = dfg.make_constant(result_y, Type::field()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); + var_points.push(result_x); + var_points.push(result_y); + var_points.push(result_is_infinity); + } + // Construct the simplified MSM expression + let typ = Type::Array(Arc::new(vec![Type::field()]), var_scalars.len()); + let scalars = Instruction::MakeArray { elements: var_scalars.into(), typ }; + let scalars = dfg + .insert_instruction_and_results(scalars, block, None, call_stack.clone()) + .first(); + let typ = Type::Array(Arc::new(vec![Type::field()]), var_points.len()); + let points = Instruction::MakeArray { elements: var_points.into(), typ }; + let points = + dfg.insert_instruction_and_results(points, block, None, call_stack.clone()).first(); + let msm = dfg.import_intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul)); + SimplifyResult::SimplifiedToInstruction(Instruction::Call { + func: msm, + arguments: vec![points, scalars], + }) } _ => SimplifyResult::None, } diff --git a/test_programs/execution_success/embedded_curve_ops/src/main.nr b/test_programs/execution_success/embedded_curve_ops/src/main.nr index e69184b9c96..85cf60dc796 100644 --- a/test_programs/execution_success/embedded_curve_ops/src/main.nr +++ b/test_programs/execution_success/embedded_curve_ops/src/main.nr @@ -20,4 +20,22 @@ fn main(priv_key: Field, pub_x: pub Field, pub_y: pub Field) { // The results should be double the g1 point because the scalars are 1 and we pass in g1 twice assert(double.x == res.x); + + // Tests for #6549 + let const_scalar1 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 23, hi: 0 }; + let const_scalar2 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 23 }; + let const_scalar3 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 13, hi: 4 }; + let partial_mul = std::embedded_curve_ops::multi_scalar_mul( + [g1, double, pub_point, g1, g1], + [scalar, const_scalar1, scalar, const_scalar2, const_scalar3], + ); + assert(partial_mul.x == 0x2024c4eebfbc8a20018f8c95c7aab77c6f34f10cf785a6f04e97452d8708fda7); + // Check simplification by zero + let zero_point = std::embedded_curve_ops::EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true }; + let const_zero = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 0 }; + let partial_mul = std::embedded_curve_ops::multi_scalar_mul( + [zero_point, double, g1], + [scalar, const_zero, scalar], + ); + assert(partial_mul == g1); }