From b8bace9a00c3a8eb93f42682e8cbfa351fc5238c Mon Sep 17 00:00:00 2001 From: Aztec Bot <49558828+AztecBot@users.noreply.github.com> Date: Thu, 21 Nov 2024 08:03:43 -0500 Subject: [PATCH] feat: Sync from noir (#10047) Automated pull of development from the [noir](https://github.com/noir-lang/noir) programming language, a dependency of Aztec. BEGIN_COMMIT_OVERRIDE feat: trait aliases (https://github.com/noir-lang/noir/pull/6431) chore: Added test showcasing performance regression (https://github.com/noir-lang/noir/pull/6566) chore: embed package name in logs (https://github.com/noir-lang/noir/pull/6564) chore: remove separate acvm versioning (https://github.com/noir-lang/noir/pull/6561) chore: switch to 1.0.0-beta versioning (https://github.com/noir-lang/noir/pull/6503) chore: Release Noir(0.39.0) (https://github.com/noir-lang/noir/pull/6484) feat: Sync from aztec-packages (https://github.com/noir-lang/noir/pull/6557) feat(ssa): Unroll small loops in brillig (https://github.com/noir-lang/noir/pull/6505) fix: Do a shallow follow_bindings before unification (https://github.com/noir-lang/noir/pull/6558) chore: remove some `_else_condition` tech debt (https://github.com/noir-lang/noir/pull/6522) chore: revert #6375 (https://github.com/noir-lang/noir/pull/6552) feat: simplify constant MSM calls in SSA (https://github.com/noir-lang/noir/pull/6547) chore(test): Remove duplicate brillig tests (https://github.com/noir-lang/noir/pull/6523) chore: restructure `noirc_evaluator` crate (https://github.com/noir-lang/noir/pull/6534) fix: take blackbox function outputs into account when merging expressions (https://github.com/noir-lang/noir/pull/6532) chore: Add `Instruction::MakeArray` to SSA (https://github.com/noir-lang/noir/pull/6071) feat(profiler): Reduce memory in Brillig execution flamegraph (https://github.com/noir-lang/noir/pull/6538) chore: convert some tests to use SSA parser (https://github.com/noir-lang/noir/pull/6543) chore(ci): bump mac github runner image to `macos-14` (https://github.com/noir-lang/noir/pull/6545) chore(test): More descriptive labels in test matrix (https://github.com/noir-lang/noir/pull/6542) chore: Remove unused methods from implicit numeric generics (https://github.com/noir-lang/noir/pull/6541) fix: Fix poor handling of aliased references in flattening pass causing some values to be zeroed (https://github.com/noir-lang/noir/pull/6434) fix: allow range checks to be performed within the comptime intepreter (https://github.com/noir-lang/noir/pull/6514) fix: disallow `#[test]` on associated functions (https://github.com/noir-lang/noir/pull/6449) chore(ssa): Skip array_set pass for Brillig functions (https://github.com/noir-lang/noir/pull/6513) chore: Reverse ssa parser diff order (https://github.com/noir-lang/noir/pull/6511) chore: Parse negatives in SSA parser (https://github.com/noir-lang/noir/pull/6510) feat: avoid unnecessary ssa passes while loop unrolling (https://github.com/noir-lang/noir/pull/6509) fix(tests): Use a file lock as well as a mutex to isolate tests cases (https://github.com/noir-lang/noir/pull/6508) fix: set local_module before elaborating each trait (https://github.com/noir-lang/noir/pull/6506) fix: parse Slice type in SSa (https://github.com/noir-lang/noir/pull/6507) fix: perform arithmetic simplification through `CheckedCast` (https://github.com/noir-lang/noir/pull/6502) feat: SSA parser (https://github.com/noir-lang/noir/pull/6489) chore(test): Run test matrix on test_programs (https://github.com/noir-lang/noir/pull/6429) chore(ci): fix cargo deny (https://github.com/noir-lang/noir/pull/6501) feat: Deduplicate instructions across blocks (https://github.com/noir-lang/noir/pull/6499) chore: move tests for arithmetic generics closer to the code (https://github.com/noir-lang/noir/pull/6497) fix(docs): Fix broken links in oracles doc (https://github.com/noir-lang/noir/pull/6488) fix: Treat all parameters as possible aliases of each other (https://github.com/noir-lang/noir/pull/6477) chore: bump rust dependencies (https://github.com/noir-lang/noir/pull/6482) feat: use a full `BlackBoxFunctionSolver` implementation when execution brillig during acirgen (https://github.com/noir-lang/noir/pull/6481) chore(docs): Update How to Oracles (https://github.com/noir-lang/noir/pull/5675) chore: Release Noir(0.38.0) (https://github.com/noir-lang/noir/pull/6422) fix(ssa): Change array_set to not mutate slices coming from function inputs (https://github.com/noir-lang/noir/pull/6463) chore: update example to show how to split public inputs in bash (https://github.com/noir-lang/noir/pull/6472) fix: Discard optimisation that would change execution ordering or that is related to call outputs (https://github.com/noir-lang/noir/pull/6461) chore: proptest for `canonicalize` on infix type expressions (https://github.com/noir-lang/noir/pull/6269) fix: let formatter respect newlines between comments (https://github.com/noir-lang/noir/pull/6458) fix: check infix expression is valid in program input (https://github.com/noir-lang/noir/pull/6450) fix: don't crash on AsTraitPath with empty path (https://github.com/noir-lang/noir/pull/6454) fix(tests): Prevent EOF error while running test programs (https://github.com/noir-lang/noir/pull/6455) fix(sea): mem2reg to treat block input references as alias (https://github.com/noir-lang/noir/pull/6452) chore: revamp attributes (https://github.com/noir-lang/noir/pull/6424) feat!: Always Check Arithmetic Generics at Monomorphization (https://github.com/noir-lang/noir/pull/6329) chore: split path and import lookups (https://github.com/noir-lang/noir/pull/6430) fix(ssa): Resolve value IDs in terminator before comparing to array (https://github.com/noir-lang/noir/pull/6448) fix: right shift is not a regular division (https://github.com/noir-lang/noir/pull/6400) END_COMMIT_OVERRIDE --------- Co-authored-by: sirasistant Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> Co-authored-by: TomAFrench --- .noir-sync-commit | 2 +- .../compiler/noirc_evaluator/Cargo.toml | 2 +- .../acir_ir => acir}/acir_variable.rs | 39 +- .../{ssa/acir_gen/acir_ir => acir}/big_int.rs | 0 .../brillig_gen => acir}/brillig_directive.rs | 0 .../acir_ir => acir}/generated_acir.rs | 22 +- .../src/{ssa/acir_gen => acir}/mod.rs | 115 +- .../src/brillig/brillig_gen.rs | 1 - .../src/brillig/brillig_gen/brillig_block.rs | 87 +- .../brillig_gen/constant_allocation.rs | 5 +- .../brillig/brillig_gen/variable_liveness.rs | 27 +- .../src/brillig/brillig_ir/codegen_stack.rs | 2 + .../compiler/noirc_evaluator/src/lib.rs | 25 +- .../compiler/noirc_evaluator/src/ssa.rs | 22 +- .../src/ssa/acir_gen/acir_ir.rs | 3 - .../check_for_underconstrained_values.rs | 6 +- .../src/ssa/function_builder/data_bus.rs | 17 +- .../src/ssa/function_builder/mod.rs | 22 +- .../noirc_evaluator/src/ssa/ir/dfg.rs | 21 +- .../noirc_evaluator/src/ssa/ir/dom.rs | 2 +- .../noirc_evaluator/src/ssa/ir/function.rs | 8 + .../src/ssa/ir/function_inserter.rs | 108 +- .../noirc_evaluator/src/ssa/ir/instruction.rs | 59 +- .../src/ssa/ir/instruction/call.rs | 110 +- .../src/ssa/ir/instruction/call/blackbox.rs | 99 +- .../src/ssa/{ir.rs => ir/mod.rs} | 0 .../noirc_evaluator/src/ssa/ir/post_order.rs | 2 +- .../noirc_evaluator/src/ssa/ir/printer.rs | 33 +- .../noirc_evaluator/src/ssa/ir/value.rs | 4 - .../noirc_evaluator/src/ssa/opt/array_set.rs | 36 +- .../src/ssa/opt/constant_folding.rs | 30 +- .../noirc_evaluator/src/ssa/opt/die.rs | 43 +- .../src/ssa/opt/flatten_cfg.rs | 110 +- .../ssa/opt/flatten_cfg/capacity_tracker.rs | 32 +- .../src/ssa/opt/flatten_cfg/value_merger.rs | 14 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 4 - .../noirc_evaluator/src/ssa/opt/mem2reg.rs | 64 +- .../noirc_evaluator/src/ssa/opt/mod.rs | 3 +- .../src/ssa/opt/normalize_value_ids.rs | 13 - .../noirc_evaluator/src/ssa/opt/rc.rs | 7 +- .../src/ssa/opt/remove_enable_side_effects.rs | 3 +- .../noirc_evaluator/src/ssa/opt/unrolling.rs | 1328 +++++++++++++---- .../noirc_evaluator/src/ssa/parser/ast.rs | 8 +- .../src/ssa/parser/into_ssa.rs | 31 +- .../src/ssa/{parser.rs => parser/mod.rs} | 30 +- .../noirc_evaluator/src/ssa/parser/tests.rs | 29 +- .../noirc_evaluator/src/ssa/parser/token.rs | 3 + .../noirc_evaluator/src/ssa/ssa_gen/mod.rs | 10 +- .../compiler/noirc_frontend/src/ast/traits.rs | 16 + .../src/hir/type_check/errors.rs | 2 +- .../noirc_frontend/src/parser/errors.rs | 2 + .../noirc_frontend/src/parser/parser/impls.rs | 2 + .../noirc_frontend/src/parser/parser/item.rs | 72 +- .../src/parser/parser/parse_many.rs | 40 +- .../src/parser/parser/traits.rs | 282 +++- .../noirc_frontend/src/tests/traits.rs | 246 +++ .../docs/docs/noir/concepts/traits.md | 89 +- .../Nargo.toml | 6 + .../src/main.nr | 12 + .../Nargo.toml | 7 + .../Prover.toml | 16 + .../src/main.nr | 104 ++ noir/noir-repo/tooling/nargo_fmt/build.rs | 5 +- .../nargo_fmt/src/formatter/trait_impl.rs | 5 + .../tooling/nargo_fmt/src/formatter/traits.rs | 17 +- .../nargo_fmt/tests/expected/trait_alias.nr | 86 ++ .../nargo_fmt/tests/input/trait_alias.nr | 78 + 67 files changed, 2777 insertions(+), 951 deletions(-) rename noir/noir-repo/compiler/noirc_evaluator/src/{ssa/acir_gen/acir_ir => acir}/acir_variable.rs (99%) rename noir/noir-repo/compiler/noirc_evaluator/src/{ssa/acir_gen/acir_ir => acir}/big_int.rs (100%) rename noir/noir-repo/compiler/noirc_evaluator/src/{brillig/brillig_gen => acir}/brillig_directive.rs (100%) rename noir/noir-repo/compiler/noirc_evaluator/src/{ssa/acir_gen/acir_ir => acir}/generated_acir.rs (98%) rename noir/noir-repo/compiler/noirc_evaluator/src/{ssa/acir_gen => acir}/mod.rs (98%) delete mode 100644 noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir.rs rename noir/noir-repo/compiler/noirc_evaluator/src/ssa/{ir.rs => ir/mod.rs} (100%) rename noir/noir-repo/compiler/noirc_evaluator/src/ssa/{parser.rs => parser/mod.rs} (97%) create mode 100644 noir/noir-repo/test_programs/compile_success_empty/embedded_curve_msm_simplification/Nargo.toml create mode 100644 noir/noir-repo/test_programs/compile_success_empty/embedded_curve_msm_simplification/src/main.nr create mode 100644 noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/Nargo.toml create mode 100644 noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/Prover.toml create mode 100644 noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/src/main.nr create mode 100644 noir/noir-repo/tooling/nargo_fmt/tests/expected/trait_alias.nr create mode 100644 noir/noir-repo/tooling/nargo_fmt/tests/input/trait_alias.nr diff --git a/.noir-sync-commit b/.noir-sync-commit index 1de91bb1a7a..9bbde85e56b 100644 --- a/.noir-sync-commit +++ b/.noir-sync-commit @@ -1 +1 @@ -13856a121125b1ccca15919942081a5d157d280e +68c32b4ffd9b069fe4b119327dbf4018c17ab9d4 diff --git a/noir/noir-repo/compiler/noirc_evaluator/Cargo.toml b/noir/noir-repo/compiler/noirc_evaluator/Cargo.toml index aac07339dd6..e25b5bf855a 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/Cargo.toml +++ b/noir/noir-repo/compiler/noirc_evaluator/Cargo.toml @@ -20,7 +20,6 @@ fxhash.workspace = true iter-extended.workspace = true thiserror.workspace = true num-bigint = "0.4" -num-traits.workspace = true im.workspace = true serde.workspace = true serde_json.workspace = true @@ -33,6 +32,7 @@ cfg-if.workspace = true [dev-dependencies] proptest.workspace = true similar-asserts.workspace = true +num-traits.workspace = true [features] bn254 = ["noirc_frontend/bn254"] diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/noir/noir-repo/compiler/noirc_evaluator/src/acir/acir_variable.rs similarity index 99% rename from noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs rename to noir/noir-repo/compiler/noirc_evaluator/src/acir/acir_variable.rs index c6e4a261897..6ba072f01a4 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/acir/acir_variable.rs @@ -1,27 +1,18 @@ -use super::big_int::BigIntContext; -use super::generated_acir::{BrilligStdlibFunc, GeneratedAcir, PLACEHOLDER_BRILLIG_INDEX}; -use crate::brillig::brillig_gen::brillig_directive; -use crate::brillig::brillig_ir::artifact::GeneratedBrillig; -use crate::errors::{InternalBug, InternalError, RuntimeError, SsaReport}; -use crate::ssa::acir_gen::{AcirDynamicArray, AcirValue}; -use crate::ssa::ir::dfg::CallStack; -use crate::ssa::ir::types::Type as SsaType; -use crate::ssa::ir::{instruction::Endian, types::NumericType}; -use acvm::acir::circuit::brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs}; -use acvm::acir::circuit::opcodes::{ - AcirFunctionId, BlockId, BlockType, ConstantOrWitnessEnum, MemOp, -}; -use acvm::acir::circuit::{AssertionPayload, ExpressionOrMemory, ExpressionWidth, Opcode}; -use acvm::brillig_vm::{MemoryValue, VMStatus, VM}; -use acvm::BlackBoxFunctionSolver; use acvm::{ - acir::AcirField, acir::{ brillig::Opcode as BrilligOpcode, - circuit::opcodes::FunctionInput, + circuit::{ + brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs}, + opcodes::{ + AcirFunctionId, BlockId, BlockType, ConstantOrWitnessEnum, FunctionInput, MemOp, + }, + AssertionPayload, ExpressionOrMemory, ExpressionWidth, Opcode, + }, native_types::{Expression, Witness}, - BlackBoxFunc, + AcirField, BlackBoxFunc, }, + brillig_vm::{MemoryValue, VMStatus, VM}, + BlackBoxFunctionSolver, }; use fxhash::FxHashMap as HashMap; use iter_extended::{try_vecmap, vecmap}; @@ -29,6 +20,16 @@ use num_bigint::BigUint; use std::cmp::Ordering; use std::{borrow::Cow, hash::Hash}; +use crate::brillig::brillig_ir::artifact::GeneratedBrillig; +use crate::errors::{InternalBug, InternalError, RuntimeError, SsaReport}; +use crate::ssa::ir::{ + dfg::CallStack, instruction::Endian, types::NumericType, types::Type as SsaType, +}; + +use super::big_int::BigIntContext; +use super::generated_acir::{BrilligStdlibFunc, GeneratedAcir, PLACEHOLDER_BRILLIG_INDEX}; +use super::{brillig_directive, AcirDynamicArray, AcirValue}; + #[derive(Clone, Debug, PartialEq, Eq, Hash)] /// High level Type descriptor for Variables. /// diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/big_int.rs b/noir/noir-repo/compiler/noirc_evaluator/src/acir/big_int.rs similarity index 100% rename from noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/big_int.rs rename to noir/noir-repo/compiler/noirc_evaluator/src/acir/big_int.rs diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs b/noir/noir-repo/compiler/noirc_evaluator/src/acir/brillig_directive.rs similarity index 100% rename from noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs rename to noir/noir-repo/compiler/noirc_evaluator/src/acir/brillig_directive.rs diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/acir/generated_acir.rs similarity index 98% rename from noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs rename to noir/noir-repo/compiler/noirc_evaluator/src/acir/generated_acir.rs index 6b215839f34..91206abe732 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/acir/generated_acir.rs @@ -1,22 +1,24 @@ //! `GeneratedAcir` is constructed as part of the `acir_gen` pass to accumulate all of the ACIR //! program as it is being converted from SSA form. -use std::{collections::BTreeMap, u32}; +use std::collections::BTreeMap; -use crate::{ - brillig::{brillig_gen::brillig_directive, brillig_ir::artifact::GeneratedBrillig}, - errors::{InternalError, RuntimeError, SsaReport}, - ssa::ir::{dfg::CallStack, instruction::ErrorType}, -}; use acvm::acir::{ circuit::{ brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs}, opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode}, AssertionPayload, BrilligOpcodeLocation, ErrorSelector, OpcodeLocation, }, - native_types::Witness, - BlackBoxFunc, + native_types::{Expression, Witness}, + AcirField, BlackBoxFunc, +}; + +use super::brillig_directive; +use crate::{ + brillig::brillig_ir::artifact::GeneratedBrillig, + errors::{InternalError, RuntimeError, SsaReport}, + ssa::ir::dfg::CallStack, + ErrorType, }; -use acvm::{acir::native_types::Expression, acir::AcirField}; use iter_extended::vecmap; use noirc_errors::debug_info::ProcedureDebugId; @@ -155,7 +157,7 @@ impl GeneratedAcir { /// This means you cannot multiply an infinite amount of `Expression`s together. /// Once the `Expression` goes over degree-2, then it needs to be reduced to a `Witness` /// which has degree-1 in order to be able to continue the multiplication chain. - pub(crate) fn create_witness_for_expression(&mut self, expression: &Expression) -> Witness { + fn create_witness_for_expression(&mut self, expression: &Expression) -> Witness { let fresh_witness = self.next_witness_index(); // Create a constraint that sets them to be equal to each other diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/acir/mod.rs similarity index 98% rename from noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs rename to noir/noir-repo/compiler/noirc_evaluator/src/acir/mod.rs index 33fdf2abc82..5c7899b5035 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/acir/mod.rs @@ -1,47 +1,58 @@ //! This file holds the pass to convert from Noir's SSA IR to ACIR. -mod acir_ir; +use fxhash::FxHashMap as HashMap; +use im::Vector; use std::collections::{BTreeMap, HashSet}; use std::fmt::Debug; -use self::acir_ir::acir_variable::{AcirContext, AcirType, AcirVar}; -use self::acir_ir::generated_acir::BrilligStdlibFunc; -use super::function_builder::data_bus::DataBus; -use super::ir::dfg::CallStack; -use super::ir::function::FunctionId; -use super::ir::instruction::ConstrainError; -use super::ir::printer::try_to_extract_string_from_error_payload; -use super::{ +use acvm::acir::{ + circuit::{ + brillig::{BrilligBytecode, BrilligFunctionId}, + opcodes::{AcirFunctionId, BlockType}, + AssertionPayload, ErrorSelector, ExpressionWidth, OpcodeLocation, + }, + native_types::Witness, + BlackBoxFunc, +}; +use acvm::{acir::circuit::opcodes::BlockId, acir::AcirField, FieldElement}; +use bn254_blackbox_solver::Bn254BlackBoxSolver; +use iter_extended::{try_vecmap, vecmap}; +use noirc_frontend::monomorphization::ast::InlineType; + +mod acir_variable; +mod big_int; +mod brillig_directive; +mod generated_acir; + +use crate::brillig::{ + brillig_gen::brillig_fn::FunctionContext as BrilligFunctionContext, + brillig_ir::{ + artifact::{BrilligParameter, GeneratedBrillig}, + BrilligContext, + }, + Brillig, +}; +use crate::errors::{InternalError, InternalWarning, RuntimeError, SsaReport}; +use crate::ssa::{ + function_builder::data_bus::DataBus, ir::{ - dfg::DataFlowGraph, - function::{Function, RuntimeType}, + dfg::{CallStack, DataFlowGraph}, + function::{Function, FunctionId, RuntimeType}, instruction::{ - Binary, BinaryOp, Instruction, InstructionId, Intrinsic, TerminatorInstruction, + Binary, BinaryOp, ConstrainError, Instruction, InstructionId, Intrinsic, + TerminatorInstruction, }, map::Id, + printer::try_to_extract_string_from_error_payload, types::{NumericType, Type}, value::{Value, ValueId}, }, ssa_gen::Ssa, }; -use crate::brillig::brillig_ir::artifact::{BrilligParameter, GeneratedBrillig}; -use crate::brillig::brillig_ir::BrilligContext; -use crate::brillig::{brillig_gen::brillig_fn::FunctionContext as BrilligFunctionContext, Brillig}; -use crate::errors::{InternalError, InternalWarning, RuntimeError, SsaReport}; -pub(crate) use acir_ir::generated_acir::GeneratedAcir; -use acvm::acir::circuit::opcodes::{AcirFunctionId, BlockType}; -use bn254_blackbox_solver::Bn254BlackBoxSolver; -use noirc_frontend::monomorphization::ast::InlineType; - -use acvm::acir::circuit::brillig::{BrilligBytecode, BrilligFunctionId}; -use acvm::acir::circuit::{AssertionPayload, ErrorSelector, ExpressionWidth, OpcodeLocation}; -use acvm::acir::native_types::Witness; -use acvm::acir::BlackBoxFunc; -use acvm::{acir::circuit::opcodes::BlockId, acir::AcirField, FieldElement}; -use fxhash::FxHashMap as HashMap; -use im::Vector; -use iter_extended::{try_vecmap, vecmap}; -use noirc_frontend::Type as HirType; +use acir_variable::{AcirContext, AcirType, AcirVar}; +use generated_acir::BrilligStdlibFunc; +pub(crate) use generated_acir::GeneratedAcir; +use noirc_frontend::hir_def::types::Type as HirType; #[derive(Default)] struct SharedContext { @@ -772,6 +783,12 @@ impl<'a> Context<'a> { Instruction::IfElse { .. } => { unreachable!("IfElse instruction remaining in acir-gen") } + Instruction::MakeArray { elements, typ: _ } => { + let elements = elements.iter().map(|element| self.convert_value(*element, dfg)); + let value = AcirValue::Array(elements.collect()); + let result = dfg.instruction_results(instruction_id)[0]; + self.ssa_values.insert(result, value); + } } self.acir_context.set_call_stack(CallStack::new()); @@ -1562,7 +1579,7 @@ impl<'a> Context<'a> { if !already_initialized { let value = &dfg[array]; match value { - Value::Array { .. } | Value::Instruction { .. } => { + Value::Instruction { .. } => { let value = self.convert_value(array, dfg); let array_typ = dfg.type_of_value(array); let len = if !array_typ.contains_slice_element() { @@ -1605,13 +1622,6 @@ impl<'a> Context<'a> { match array_typ { Type::Array(_, _) | Type::Slice(_) => { match &dfg[array_id] { - Value::Array { array, .. } => { - for (i, value) in array.iter().enumerate() { - flat_elem_type_sizes.push( - self.flattened_slice_size(*value, dfg) + flat_elem_type_sizes[i], - ); - } - } Value::Instruction { .. } | Value::Param { .. } => { // An instruction representing the slice means it has been processed previously during ACIR gen. // Use the previously defined result of an array operation to fetch the internal type information. @@ -1744,13 +1754,6 @@ impl<'a> Context<'a> { fn flattened_slice_size(&mut self, array_id: ValueId, dfg: &DataFlowGraph) -> usize { let mut size = 0; match &dfg[array_id] { - Value::Array { array, .. } => { - // The array is going to be the flattened outer array - // Flattened slice size from SSA value does not need to be multiplied by the len - for value in array { - size += self.flattened_slice_size(*value, dfg); - } - } Value::NumericConstant { .. } => { size += 1; } @@ -1914,10 +1917,6 @@ impl<'a> Context<'a> { Value::NumericConstant { constant, typ } => { AcirValue::Var(self.acir_context.add_constant(*constant), typ.into()) } - Value::Array { array, .. } => { - let elements = array.iter().map(|element| self.convert_value(*element, dfg)); - AcirValue::Array(elements.collect()) - } Value::Intrinsic(..) => todo!(), Value::Function(function_id) => { // This conversion is for debugging support only, to allow the @@ -2840,22 +2839,6 @@ impl<'a> Context<'a> { Ok(()) } - /// Given an array value, return the numerical type of its element. - /// Panics if the given value is not an array or has a non-numeric element type. - fn array_element_type(dfg: &DataFlowGraph, value: ValueId) -> AcirType { - match dfg.type_of_value(value) { - Type::Array(elements, _) => { - assert_eq!(elements.len(), 1); - (&elements[0]).into() - } - Type::Slice(elements) => { - assert_eq!(elements.len(), 1); - (&elements[0]).into() - } - _ => unreachable!("Expected array type"), - } - } - /// Convert a Vec into a Vec using the given result ids. /// If the type of a result id is an array, several acir vars are collected into /// a single AcirValue::Array of the same length. @@ -2946,9 +2929,9 @@ mod test { use std::collections::BTreeMap; use crate::{ + acir::BrilligStdlibFunc, brillig::Brillig, ssa::{ - acir_gen::acir_ir::generated_acir::BrilligStdlibFunc, function_builder::FunctionBuilder, ir::{function::FunctionId, instruction::BinaryOp, map::Id, types::Type}, }, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs index 313fd65a197..786a03031d6 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs @@ -1,7 +1,6 @@ pub(crate) mod brillig_black_box; pub(crate) mod brillig_block; pub(crate) mod brillig_block_variables; -pub(crate) mod brillig_directive; pub(crate) mod brillig_fn; pub(crate) mod brillig_slice_ops; mod constant_allocation; diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index 40224e132ab..36e1ee90e11 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -160,13 +160,9 @@ impl<'block> BrilligBlock<'block> { ); } TerminatorInstruction::Return { return_values, .. } => { - let return_registers: Vec<_> = return_values - .iter() - .map(|value_id| { - let return_variable = self.convert_ssa_value(*value_id, dfg); - return_variable.extract_register() - }) - .collect(); + let return_registers = vecmap(return_values, |value_id| { + self.convert_ssa_value(*value_id, dfg).extract_register() + }); self.brillig_context.codegen_return(&return_registers); } } @@ -763,6 +759,43 @@ impl<'block> BrilligBlock<'block> { Instruction::IfElse { .. } => { unreachable!("IfElse instructions should not be possible in brillig") } + Instruction::MakeArray { elements: array, typ } => { + let value_id = dfg.instruction_results(instruction_id)[0]; + if !self.variables.is_allocated(&value_id) { + let new_variable = self.variables.define_variable( + self.function_context, + self.brillig_context, + value_id, + dfg, + ); + + // Initialize the variable + match new_variable { + BrilligVariable::BrilligArray(brillig_array) => { + self.brillig_context.codegen_initialize_array(brillig_array); + } + BrilligVariable::BrilligVector(vector) => { + let size = self + .brillig_context + .make_usize_constant_instruction(array.len().into()); + self.brillig_context.codegen_initialize_vector(vector, size, None); + self.brillig_context.deallocate_single_addr(size); + } + _ => unreachable!( + "ICE: Cannot initialize array value created as {new_variable:?}" + ), + }; + + // Write the items + let items_pointer = self + .brillig_context + .codegen_make_array_or_vector_items_pointer(new_variable); + + self.initialize_constant_array(array, typ, dfg, items_pointer); + + self.brillig_context.deallocate_register(items_pointer); + } + } }; let dead_variables = self @@ -1500,46 +1533,6 @@ impl<'block> BrilligBlock<'block> { new_variable } } - Value::Array { array, typ } => { - if self.variables.is_allocated(&value_id) { - self.variables.get_allocation(self.function_context, value_id, dfg) - } else { - let new_variable = self.variables.define_variable( - self.function_context, - self.brillig_context, - value_id, - dfg, - ); - - // Initialize the variable - match new_variable { - BrilligVariable::BrilligArray(brillig_array) => { - self.brillig_context.codegen_initialize_array(brillig_array); - } - BrilligVariable::BrilligVector(vector) => { - let size = self - .brillig_context - .make_usize_constant_instruction(array.len().into()); - self.brillig_context.codegen_initialize_vector(vector, size, None); - self.brillig_context.deallocate_single_addr(size); - } - _ => unreachable!( - "ICE: Cannot initialize array value created as {new_variable:?}" - ), - }; - - // Write the items - let items_pointer = self - .brillig_context - .codegen_make_array_or_vector_items_pointer(new_variable); - - self.initialize_constant_array(array, typ, dfg, items_pointer); - - self.brillig_context.deallocate_register(items_pointer); - - new_variable - } - } Value::Function(_) => { // For the debugger instrumentation we want to allow passing // around values representing function pointers, even though diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs index f9ded224b33..61ca20be2f5 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs @@ -89,8 +89,7 @@ impl ConstantAllocation { } if let Some(terminator_instruction) = block.terminator() { terminator_instruction.for_each_value(|value_id| { - let variables = collect_variables_of_value(value_id, &func.dfg); - for variable in variables { + if let Some(variable) = collect_variables_of_value(value_id, &func.dfg) { record_if_constant(block_id, variable, InstructionLocation::Terminator); } }); @@ -166,7 +165,7 @@ impl ConstantAllocation { } pub(crate) fn is_constant_value(id: ValueId, dfg: &DataFlowGraph) -> bool { - matches!(&dfg[dfg.resolve(id)], Value::NumericConstant { .. } | Value::Array { .. }) + matches!(&dfg[dfg.resolve(id)], Value::NumericConstant { .. }) } /// For a given function, finds all the blocks that are within loops diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/variable_liveness.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/variable_liveness.rs index a18461bc0cd..87165c36dff 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/variable_liveness.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/variable_liveness.rs @@ -45,32 +45,19 @@ fn find_back_edges( } /// Collects the underlying variables inside a value id. It might be more than one, for example in constant arrays that are constructed with multiple vars. -pub(crate) fn collect_variables_of_value(value_id: ValueId, dfg: &DataFlowGraph) -> Vec { +pub(crate) fn collect_variables_of_value( + value_id: ValueId, + dfg: &DataFlowGraph, +) -> Option { let value_id = dfg.resolve(value_id); let value = &dfg[value_id]; match value { - Value::Instruction { .. } | Value::Param { .. } => { - vec![value_id] - } - // Literal arrays are constants, but might use variable values to initialize. - Value::Array { array, .. } => { - let mut value_ids = vec![value_id]; - - array.iter().for_each(|item_id| { - let underlying_ids = collect_variables_of_value(*item_id, dfg); - value_ids.extend(underlying_ids); - }); - - value_ids - } - Value::NumericConstant { .. } => { - vec![value_id] + Value::Instruction { .. } | Value::Param { .. } | Value::NumericConstant { .. } => { + Some(value_id) } // Functions are not variables in a defunctionalized SSA. Only constant function values should appear. - Value::ForeignFunction(_) | Value::Function(_) | Value::Intrinsic(..) => { - vec![] - } + Value::ForeignFunction(_) | Value::Function(_) | Value::Intrinsic(..) => None, } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs index a0e2a500e20..599c05fc0e8 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_stack.rs @@ -47,6 +47,7 @@ impl BrilligContext< let destinations_of_temp = movements_map.remove(first_source).unwrap(); movements_map.insert(temp_register, destinations_of_temp); } + // After removing loops we should have an DAG with each node having only one ancestor (but could have multiple successors) // Now we should be able to move the registers just by performing a DFS on the movements map let heads: Vec<_> = movements_map @@ -54,6 +55,7 @@ impl BrilligContext< .filter(|source| !destinations_set.contains(source)) .copied() .collect(); + for head in heads { self.perform_movements(&movements_map, head); } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/lib.rs b/noir/noir-repo/compiler/noirc_evaluator/src/lib.rs index 5f0c7a5bbb8..8127e3d03ef 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/lib.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/lib.rs @@ -5,11 +5,9 @@ pub mod errors; -// SSA code to create the SSA based IR -// for functions and execute different optimizations. -pub mod ssa; - +mod acir; pub mod brillig; +pub mod ssa; pub use ssa::create_program; pub use ssa::ir::instruction::ErrorType; @@ -31,3 +29,22 @@ pub(crate) fn trim_leading_whitespace_from_lines(src: &str) -> String { } result } + +/// Trim comments from the lines, ie. content starting with `//`. +#[cfg(test)] +pub(crate) fn trim_comments_from_lines(src: &str) -> String { + let mut result = String::new(); + let mut first = true; + for line in src.lines() { + if !first { + result.push('\n'); + } + if let Some(comment) = line.find("//") { + result.push_str(line[..comment].trim_end()); + } else { + result.push_str(line); + } + first = false; + } + result +} diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs index 5e2c0f0827d..9e11441caf4 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs @@ -31,19 +31,17 @@ use noirc_errors::debug_info::{DebugFunctions, DebugInfo, DebugTypes, DebugVaria use noirc_frontend::ast::Visibility; use noirc_frontend::{hir_def::function::FunctionSignature, monomorphization::ast::Program}; +use ssa_gen::Ssa; use tracing::{span, Level}; -use self::{ - acir_gen::{Artifacts, GeneratedAcir}, - ssa_gen::Ssa, -}; +use crate::acir::{Artifacts, GeneratedAcir}; -mod acir_gen; mod checks; pub(super) mod function_builder; pub mod ir; mod opt; -mod parser; +#[cfg(test)] +pub(crate) mod parser; pub mod ssa_gen; pub struct SsaEvaluatorOptions { @@ -96,28 +94,28 @@ pub(crate) fn optimize_into_acir( .run_pass(Ssa::remove_paired_rc, "After Removing Paired rc_inc & rc_decs:") .run_pass(Ssa::separate_runtime, "After Runtime Separation:") .run_pass(Ssa::resolve_is_unconstrained, "After Resolving IsUnconstrained:") - .run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "After Inlining:") + .run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "After Inlining (1st):") // Run mem2reg with the CFG separated into blocks - .run_pass(Ssa::mem2reg, "After Mem2Reg:") - .run_pass(Ssa::simplify_cfg, "After Simplifying:") + .run_pass(Ssa::mem2reg, "After Mem2Reg (1st):") + .run_pass(Ssa::simplify_cfg, "After Simplifying (1st):") .run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization") .try_run_pass( Ssa::evaluate_static_assert_and_assert_constant, "After `static_assert` and `assert_constant`:", )? .try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")? - .run_pass(Ssa::simplify_cfg, "After Simplifying:") + .run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):") .run_pass(Ssa::flatten_cfg, "After Flattening:") .run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:") // Run mem2reg once more with the flattened CFG to catch any remaining loads/stores - .run_pass(Ssa::mem2reg, "After Mem2Reg:") + .run_pass(Ssa::mem2reg, "After Mem2Reg (2nd):") // Run the inlining pass again to handle functions with `InlineType::NoPredicates`. // Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point. // This pass must come immediately following `mem2reg` as the succeeding passes // may create an SSA which inlining fails to handle. .run_pass( |ssa| ssa.inline_functions_with_no_predicates(options.inliner_aggressiveness), - "After Inlining:", + "After Inlining (2nd):", ) .run_pass(Ssa::remove_if_else, "After Remove IfElse:") .run_pass(Ssa::fold_constants, "After Constant Folding:") diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir.rs deleted file mode 100644 index 090d5bb0a83..00000000000 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub(crate) mod acir_variable; -pub(crate) mod big_int; -pub(crate) mod generated_acir; diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs index 90eb79ccb69..cf884c98be9 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs @@ -191,7 +191,8 @@ impl Context { | Instruction::Load { .. } | Instruction::Not(..) | Instruction::Store { .. } - | Instruction::Truncate { .. } => { + | Instruction::Truncate { .. } + | Instruction::MakeArray { .. } => { self.value_sets.push(instruction_arguments_and_results); } @@ -247,8 +248,7 @@ impl Context { Value::ForeignFunction(..) => { panic!("Should not be able to reach foreign function from non-brillig functions, {func_id} in function {}", function.name()); } - Value::Array { .. } - | Value::Instruction { .. } + Value::Instruction { .. } | Value::NumericConstant { .. } | Value::Param { .. } => { panic!("At the point we are running disconnect there shouldn't be any other values as arguments") diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs index 5a62e9c8e9a..e4a2eeb8c22 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs @@ -1,6 +1,6 @@ use std::{collections::BTreeMap, sync::Arc}; -use crate::ssa::ir::{types::Type, value::ValueId}; +use crate::ssa::ir::{function::RuntimeType, types::Type, value::ValueId}; use acvm::FieldElement; use fxhash::FxHashMap as HashMap; use noirc_frontend::ast; @@ -100,7 +100,8 @@ impl DataBus { ) -> DataBus { let mut call_data_args = Vec::new(); for call_data_item in call_data { - let array_id = call_data_item.databus.expect("Call data should have an array id"); + // databus can be None if `main` is a brillig function + let Some(array_id) = call_data_item.databus else { continue }; let call_data_id = call_data_item.call_data_id.expect("Call data should have a user id"); call_data_args.push(CallData { array_id, call_data_id, index_map: call_data_item.map }); @@ -161,13 +162,11 @@ impl FunctionBuilder { } let len = databus.values.len(); - let array = if len > 0 { - let array = self - .array_constant(databus.values, Type::Array(Arc::new(vec![Type::field()]), len)); - Some(array) - } else { - None - }; + let array = (len > 0 && matches!(self.current_function.runtime(), RuntimeType::Acir(_))) + .then(|| { + let array_type = Type::Array(Arc::new(vec![Type::field()]), len); + self.insert_make_array(databus.values, array_type) + }); DataBusBuilder { index: 0, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 63a9453a430..0479f8da0b7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -137,11 +137,6 @@ impl FunctionBuilder { self.numeric_constant(value.into(), Type::length_type()) } - /// Insert an array constant into the current function with the given element values. - pub(crate) fn array_constant(&mut self, elements: im::Vector, typ: Type) -> ValueId { - self.current_function.dfg.make_array(elements, typ) - } - /// Returns the type of the given value. pub(crate) fn type_of_value(&self, value: ValueId) -> Type { self.current_function.dfg.type_of_value(value) @@ -356,6 +351,17 @@ impl FunctionBuilder { self.insert_instruction(Instruction::EnableSideEffectsIf { condition }, None); } + /// Insert a `make_array` instruction to create a new array or slice. + /// Returns the new array value. Expects `typ` to be an array or slice type. + pub(crate) fn insert_make_array( + &mut self, + elements: im::Vector, + typ: Type, + ) -> ValueId { + assert!(matches!(typ, Type::Array(..) | Type::Slice(_))); + self.insert_instruction(Instruction::MakeArray { elements, typ }, None).first() + } + /// Terminates the current block with the given terminator instruction /// if the current block does not already have a terminator instruction. fn terminate_block_with(&mut self, terminator: TerminatorInstruction) { @@ -511,7 +517,6 @@ mod tests { instruction::{Endian, Intrinsic}, map::Id, types::Type, - value::Value, }; use super::FunctionBuilder; @@ -533,10 +538,7 @@ mod tests { let call_results = builder.insert_call(to_bits_id, vec![input, length], result_types).into_owned(); - let slice = match &builder.current_function.dfg[call_results[0]] { - Value::Array { array, .. } => array, - _ => panic!(), - }; + let slice = builder.current_function.dfg.get_array_constant(call_results[0]).unwrap().0; assert_eq!(slice[0], one); assert_eq!(slice[1], one); assert_eq!(slice[2], one); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index 2be9ffa9afa..e3f3f33682b 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -266,12 +266,6 @@ impl DataFlowGraph { id } - /// Create a new constant array value from the given elements - pub(crate) fn make_array(&mut self, array: im::Vector, typ: Type) -> ValueId { - assert!(matches!(typ, Type::Array(..) | Type::Slice(_))); - self.make_value(Value::Array { array, typ }) - } - /// Gets or creates a ValueId for the given FunctionId. pub(crate) fn import_function(&mut self, function: FunctionId) -> ValueId { if let Some(existing) = self.functions.get(&function) { @@ -458,8 +452,11 @@ impl DataFlowGraph { /// Otherwise, this returns None. pub(crate) fn get_array_constant(&self, value: ValueId) -> Option<(im::Vector, Type)> { match &self.values[self.resolve(value)] { + Value::Instruction { instruction, .. } => match &self.instructions[*instruction] { + Instruction::MakeArray { elements, typ } => Some((elements.clone(), typ.clone())), + _ => None, + }, // Arrays are shared, so cloning them is cheap - Value::Array { array, typ } => Some((array.clone(), typ.clone())), _ => None, } } @@ -522,8 +519,13 @@ impl DataFlowGraph { /// True if the given ValueId refers to a (recursively) constant value pub(crate) fn is_constant(&self, argument: ValueId) -> bool { match &self[self.resolve(argument)] { - Value::Instruction { .. } | Value::Param { .. } => false, - Value::Array { array, .. } => array.iter().all(|element| self.is_constant(*element)), + Value::Param { .. } => false, + Value::Instruction { instruction, .. } => match &self[*instruction] { + Instruction::MakeArray { elements, .. } => { + elements.iter().all(|element| self.is_constant(*element)) + } + _ => false, + }, _ => true, } } @@ -575,6 +577,7 @@ impl std::ops::IndexMut for DataFlowGraph { // The result of calling DataFlowGraph::insert_instruction can // be a list of results or a single ValueId if the instruction was simplified // to an existing value. +#[derive(Debug)] pub(crate) enum InsertInstructionResult<'dfg> { /// Results is the standard case containing the instruction id and the results of that instruction. Results(InstructionId, &'dfg [ValueId]), diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dom.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dom.rs index 94f7a405c05..c1a7f14e0d1 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dom.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dom.rs @@ -86,7 +86,7 @@ impl DominatorTree { /// /// This function panics if either of the blocks are unreachable. /// - /// An instruction is considered to dominate itself. + /// A block is considered to dominate itself. pub(crate) fn dominates(&mut self, block_a_id: BasicBlockId, block_b_id: BasicBlockId) -> bool { if let Some(res) = self.cache.get(&(block_a_id, block_b_id)) { return *res; diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function.rs index e8245ff6036..b1233e3063e 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -46,6 +46,14 @@ impl RuntimeType { | RuntimeType::Brillig(InlineType::NoPredicates) ) } + + pub(crate) fn is_brillig(&self) -> bool { + matches!(self, RuntimeType::Brillig(_)) + } + + pub(crate) fn is_acir(&self) -> bool { + matches!(self, RuntimeType::Acir(_)) + } } /// A function holds a list of instructions. diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs index 991ff22c902..5e133072067 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs @@ -18,15 +18,26 @@ pub(crate) struct FunctionInserter<'f> { pub(crate) function: &'f mut Function, values: HashMap, + /// Map containing repeat array constants so that we do not initialize a new /// array unnecessarily. An extra tuple field is included as part of the key to /// distinguish between array/slice types. - const_arrays: HashMap<(im::Vector, Type), ValueId>, + /// + /// This is optional since caching arrays relies on the inserter inserting strictly + /// in control-flow order. Otherwise, if arrays later in the program are cached first, + /// they may be refered to by instructions earlier in the program. + array_cache: Option, + + /// If this pass is loop unrolling, store the block before the loop to optionally + /// hoist any make_array instructions up to after they are retrieved from the `array_cache`. + pre_loop: Option, } +pub(crate) type ArrayCache = HashMap, HashMap>; + impl<'f> FunctionInserter<'f> { pub(crate) fn new(function: &'f mut Function) -> FunctionInserter<'f> { - Self { function, values: HashMap::default(), const_arrays: HashMap::default() } + Self { function, values: HashMap::default(), array_cache: None, pre_loop: None } } /// Resolves a ValueId to its new, updated value. @@ -36,27 +47,7 @@ impl<'f> FunctionInserter<'f> { value = self.function.dfg.resolve(value); match self.values.get(&value) { Some(value) => self.resolve(*value), - None => match &self.function.dfg[value] { - super::value::Value::Array { array, typ } => { - let array = array.clone(); - let typ = typ.clone(); - let new_array: im::Vector = - array.iter().map(|id| self.resolve(*id)).collect(); - - if let Some(fetched_value) = - self.const_arrays.get(&(new_array.clone(), typ.clone())) - { - return *fetched_value; - }; - - let new_array_clone = new_array.clone(); - let new_id = self.function.dfg.make_array(new_array, typ.clone()); - self.values.insert(value, new_id); - self.const_arrays.insert((new_array_clone, typ), new_id); - new_id - } - _ => value, - }, + None => value, } } @@ -80,6 +71,7 @@ impl<'f> FunctionInserter<'f> { } } + /// Get an instruction and make sure all the values in it are freshly resolved. pub(crate) fn map_instruction(&mut self, id: InstructionId) -> (Instruction, CallStack) { ( self.function.dfg[id].clone().map_values(|id| self.resolve(id)), @@ -122,7 +114,7 @@ impl<'f> FunctionInserter<'f> { &mut self, instruction: Instruction, id: InstructionId, - block: BasicBlockId, + mut block: BasicBlockId, call_stack: CallStack, ) -> InsertInstructionResult { let results = self.function.dfg.instruction_results(id); @@ -132,6 +124,30 @@ impl<'f> FunctionInserter<'f> { .requires_ctrl_typevars() .then(|| vecmap(&results, |result| self.function.dfg.type_of_value(*result))); + // Large arrays can lead to OOM panics if duplicated from being unrolled in loops. + // To prevent this, try to reuse the same ID for identical arrays instead of inserting + // another MakeArray instruction. Note that this assumes the function inserter is inserting + // in control-flow order. Otherwise we could refer to ValueIds defined later in the program. + let make_array = if let Instruction::MakeArray { elements, typ } = &instruction { + if self.array_is_constant(elements) { + if let Some(fetched_value) = self.get_cached_array(elements, typ) { + assert_eq!(results.len(), 1); + self.values.insert(results[0], fetched_value); + return InsertInstructionResult::SimplifiedTo(fetched_value); + } + + // Hoist constant arrays out of the loop and cache their value + if let Some(pre_loop) = self.pre_loop { + block = pre_loop; + } + Some((elements.clone(), typ.clone())) + } else { + None + } + } else { + None + }; + let new_results = self.function.dfg.insert_instruction_and_results( instruction, block, @@ -139,10 +155,54 @@ impl<'f> FunctionInserter<'f> { call_stack, ); + // Cache an array in the fresh_array_cache if array caching is enabled. + // The fresh cache isn't used for deduplication until an external pass confirms we + // pass a sequence point and all blocks that may be before the current insertion point + // are finished. + if let Some((elements, typ)) = make_array { + Self::cache_array(&mut self.array_cache, elements, typ, new_results.first()); + } + Self::insert_new_instruction_results(&mut self.values, &results, &new_results); new_results } + fn get_cached_array(&self, elements: &im::Vector, typ: &Type) -> Option { + self.array_cache.as_ref()?.get(elements)?.get(typ).copied() + } + + fn cache_array( + arrays: &mut Option, + elements: im::Vector, + typ: Type, + result_id: ValueId, + ) { + if let Some(arrays) = arrays { + arrays.entry(elements).or_default().insert(typ, result_id); + } + } + + fn array_is_constant(&self, elements: &im::Vector) -> bool { + elements.iter().all(|element| self.function.dfg.is_constant(*element)) + } + + pub(crate) fn set_array_cache( + &mut self, + new_cache: Option, + pre_loop: BasicBlockId, + ) { + self.array_cache = new_cache; + self.pre_loop = Some(pre_loop); + } + + /// Finish this inserter, returning its array cache merged with the fresh array cache. + /// Since this consumes the inserter this assumes we're at a sequence point where all + /// predecessor blocks to the current block are finished. Since this is true, the fresh + /// array cache can be merged with the existing array cache. + pub(crate) fn into_array_cache(self) -> Option { + self.array_cache + } + /// Modify the values HashMap to remember the mapping between an instruction result's previous /// ValueId (from the source_function) and its new ValueId in the destination function. pub(crate) fn insert_new_instruction_results( diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 2b40bccca7b..936dc854c51 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -278,6 +278,12 @@ pub(crate) enum Instruction { else_condition: ValueId, else_value: ValueId, }, + + /// Creates a new array or slice. + /// + /// `typ` should be an array or slice type with an element type + /// matching each of the `elements` values' types. + MakeArray { elements: im::Vector, typ: Type }, } impl Instruction { @@ -290,7 +296,9 @@ impl Instruction { pub(crate) fn result_type(&self) -> InstructionResultType { match self { Instruction::Binary(binary) => binary.result_type(), - Instruction::Cast(_, typ) => InstructionResultType::Known(typ.clone()), + Instruction::Cast(_, typ) | Instruction::MakeArray { typ, .. } => { + InstructionResultType::Known(typ.clone()) + } Instruction::Not(value) | Instruction::Truncate { value, .. } | Instruction::ArraySet { array: value, .. } @@ -344,6 +352,9 @@ impl Instruction { // We can deduplicate these instructions if we know the predicate is also the same. Constrain(..) | RangeCheck { .. } => deduplicate_with_predicate, + // This should never be side-effectful + MakeArray { .. } => true, + // These can have different behavior depending on the EnableSideEffectsIf context. // Replacing them with a similar instruction potentially enables replacing an instruction // with one that was disabled. See @@ -381,7 +392,8 @@ impl Instruction { | Load { .. } | ArrayGet { .. } | IfElse { .. } - | ArraySet { .. } => true, + | ArraySet { .. } + | MakeArray { .. } => true, Constrain(..) | Store { .. } @@ -444,7 +456,8 @@ impl Instruction { | Instruction::Store { .. } | Instruction::IfElse { .. } | Instruction::IncrementRc { .. } - | Instruction::DecrementRc { .. } => false, + | Instruction::DecrementRc { .. } + | Instruction::MakeArray { .. } => false, } } @@ -519,6 +532,10 @@ impl Instruction { else_value: f(*else_value), } } + Instruction::MakeArray { elements, typ } => Instruction::MakeArray { + elements: elements.iter().copied().map(f).collect(), + typ: typ.clone(), + }, } } @@ -579,6 +596,11 @@ impl Instruction { f(*else_condition); f(*else_value); } + Instruction::MakeArray { elements, typ: _ } => { + for element in elements { + f(*element); + } + } } } @@ -634,20 +656,28 @@ impl Instruction { None } } - Instruction::ArraySet { array, index, value, .. } => { - let array_const = dfg.get_array_constant(*array); - let index_const = dfg.get_numeric_constant(*index); - if let (Some((array, element_type)), Some(index)) = (array_const, index_const) { + Instruction::ArraySet { array: array_id, index: index_id, value, .. } => { + let array = dfg.get_array_constant(*array_id); + let index = dfg.get_numeric_constant(*index_id); + if let (Some((array, _element_type)), Some(index)) = (array, index) { let index = index.try_to_u32().expect("Expected array index to fit in u32") as usize; if index < array.len() { - let new_array = dfg.make_array(array.update(index, *value), element_type); - return SimplifiedTo(new_array); + let elements = array.update(index, *value); + let typ = dfg.type_of_value(*array_id); + let instruction = Instruction::MakeArray { elements, typ }; + let new_array = dfg.insert_instruction_and_results( + instruction, + block, + Option::None, + call_stack.clone(), + ); + return SimplifiedTo(new_array.first()); } } - try_optimize_array_set_from_previous_get(dfg, *array, *index, *value) + try_optimize_array_set_from_previous_get(dfg, *array_id, *index_id, *value) } Instruction::Truncate { value, bit_size, max_bit_size } => { if bit_size == max_bit_size { @@ -760,6 +790,7 @@ impl Instruction { None } } + Instruction::MakeArray { .. } => None, } } } @@ -803,13 +834,13 @@ fn try_optimize_array_get_from_previous_set( return SimplifyResult::None; } } + Instruction::MakeArray { elements: array, typ: _ } => { + elements = Some(array.clone()); + break; + } _ => return SimplifyResult::None, } } - Value::Array { array, typ: _ } => { - elements = Some(array.clone()); - break; - } _ => return SimplifyResult::None, } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index 9dbd2c56993..e1e967b9a43 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -60,7 +60,7 @@ pub(super) fn simplify_call( } else { unreachable!("ICE: Intrinsic::ToRadix return type must be array") }; - constant_to_radix(endian, field, 2, limb_count, dfg) + constant_to_radix(endian, field, 2, limb_count, dfg, block, call_stack) } else { SimplifyResult::None } @@ -77,7 +77,7 @@ pub(super) fn simplify_call( } else { unreachable!("ICE: Intrinsic::ToRadix return type must be array") }; - constant_to_radix(endian, field, radix, limb_count, dfg) + constant_to_radix(endian, field, radix, limb_count, dfg, block, call_stack) } else { SimplifyResult::None } @@ -109,7 +109,8 @@ pub(super) fn simplify_call( let slice_length_value = array.len() / elements_size; let slice_length = dfg.make_constant(slice_length_value.into(), Type::length_type()); - let new_slice = dfg.make_array(array, Type::Slice(inner_element_types)); + let new_slice = + make_array(dfg, array, Type::Slice(inner_element_types), block, call_stack); SimplifyResult::SimplifiedToMultiple(vec![slice_length, new_slice]) } else { SimplifyResult::None @@ -129,7 +130,7 @@ pub(super) fn simplify_call( let new_slice_length = update_slice_length(arguments[0], dfg, BinaryOp::Add, block); - let new_slice = dfg.make_array(slice, element_type); + let new_slice = make_array(dfg, slice, element_type, block, call_stack); return SimplifyResult::SimplifiedToMultiple(vec![new_slice_length, new_slice]); } @@ -154,7 +155,7 @@ pub(super) fn simplify_call( let new_slice_length = update_slice_length(arguments[0], dfg, BinaryOp::Add, block); - let new_slice = dfg.make_array(slice, element_type); + let new_slice = make_array(dfg, slice, element_type, block, call_stack); SimplifyResult::SimplifiedToMultiple(vec![new_slice_length, new_slice]) } else { SimplifyResult::None @@ -196,7 +197,7 @@ pub(super) fn simplify_call( results.push(new_slice_length); - let new_slice = dfg.make_array(slice, typ); + let new_slice = make_array(dfg, slice, typ, block, call_stack); // The slice is the last item returned for pop_front results.push(new_slice); @@ -227,7 +228,7 @@ pub(super) fn simplify_call( let new_slice_length = update_slice_length(arguments[0], dfg, BinaryOp::Add, block); - let new_slice = dfg.make_array(slice, typ); + let new_slice = make_array(dfg, slice, typ, block, call_stack); SimplifyResult::SimplifiedToMultiple(vec![new_slice_length, new_slice]) } else { SimplifyResult::None @@ -260,7 +261,7 @@ pub(super) fn simplify_call( results.push(slice.remove(index)); } - let new_slice = dfg.make_array(slice, typ); + let new_slice = make_array(dfg, slice, typ, block, call_stack); results.insert(0, new_slice); let new_slice_length = update_slice_length(arguments[0], dfg, BinaryOp::Sub, block); @@ -317,7 +318,9 @@ pub(super) fn simplify_call( SimplifyResult::None } } - Intrinsic::BlackBox(bb_func) => simplify_black_box_func(bb_func, arguments, dfg), + Intrinsic::BlackBox(bb_func) => { + simplify_black_box_func(bb_func, arguments, dfg, block, call_stack) + } Intrinsic::AsField => { let instruction = Instruction::Cast( arguments[0], @@ -350,7 +353,7 @@ pub(super) fn simplify_call( Intrinsic::IsUnconstrained => SimplifyResult::None, Intrinsic::DerivePedersenGenerators => { if let Some(Type::Array(_, len)) = ctrl_typevars.unwrap().first() { - simplify_derive_generators(dfg, arguments, *len as u32) + simplify_derive_generators(dfg, arguments, *len as u32, block, call_stack) } else { unreachable!("Derive Pedersen Generators must return an array"); } @@ -419,7 +422,7 @@ fn simplify_slice_push_back( } let slice_size = slice.len(); let element_size = element_type.element_size(); - let new_slice = dfg.make_array(slice, element_type); + let new_slice = make_array(dfg, slice, element_type, block, &call_stack); let set_last_slice_value_instr = Instruction::ArraySet { array: new_slice, @@ -505,6 +508,8 @@ fn simplify_black_box_func( bb_func: BlackBoxFunc, arguments: &[ValueId], dfg: &mut DataFlowGraph, + block: BasicBlockId, + call_stack: &CallStack, ) -> SimplifyResult { cfg_if::cfg_if! { if #[cfg(feature = "bn254")] { @@ -514,8 +519,12 @@ fn simplify_black_box_func( } }; match bb_func { - BlackBoxFunc::Blake2s => simplify_hash(dfg, arguments, acvm::blackbox_solver::blake2s), - BlackBoxFunc::Blake3 => simplify_hash(dfg, arguments, acvm::blackbox_solver::blake3), + BlackBoxFunc::Blake2s => { + simplify_hash(dfg, arguments, acvm::blackbox_solver::blake2s, block, call_stack) + } + BlackBoxFunc::Blake3 => { + simplify_hash(dfg, arguments, acvm::blackbox_solver::blake3, block, call_stack) + } BlackBoxFunc::Keccakf1600 => { if let Some((array_input, _)) = dfg.get_array_constant(arguments[0]) { if array_is_constant(dfg, &array_input) { @@ -533,8 +542,14 @@ fn simplify_black_box_func( const_input.try_into().expect("Keccakf1600 input should have length of 25"), ) .expect("Rust solvable black box function should not fail"); - let state_values = vecmap(state, |x| FieldElement::from(x as u128)); - let result_array = make_constant_array(dfg, state_values, Type::unsigned(64)); + let state_values = state.iter().map(|x| FieldElement::from(*x as u128)); + let result_array = make_constant_array( + dfg, + state_values, + Type::unsigned(64), + block, + call_stack, + ); SimplifyResult::SimplifiedTo(result_array) } else { SimplifyResult::None @@ -544,7 +559,7 @@ fn simplify_black_box_func( } } BlackBoxFunc::Poseidon2Permutation => { - blackbox::simplify_poseidon2_permutation(dfg, solver, arguments) + blackbox::simplify_poseidon2_permutation(dfg, solver, arguments, block, call_stack) } BlackBoxFunc::EcdsaSecp256k1 => blackbox::simplify_signature( dfg, @@ -557,8 +572,12 @@ fn simplify_black_box_func( acvm::blackbox_solver::ecdsa_secp256r1_verify, ), - BlackBoxFunc::MultiScalarMul => SimplifyResult::None, - BlackBoxFunc::EmbeddedCurveAdd => blackbox::simplify_ec_add(dfg, solver, arguments), + BlackBoxFunc::MultiScalarMul => { + blackbox::simplify_msm(dfg, solver, arguments, block, call_stack) + } + BlackBoxFunc::EmbeddedCurveAdd => { + blackbox::simplify_ec_add(dfg, solver, arguments, block, call_stack) + } BlackBoxFunc::SchnorrVerify => blackbox::simplify_schnorr_verify(dfg, solver, arguments), BlackBoxFunc::BigIntAdd @@ -585,23 +604,47 @@ fn simplify_black_box_func( } } -fn make_constant_array(dfg: &mut DataFlowGraph, results: Vec, typ: Type) -> ValueId { - let result_constants = vecmap(results, |element| dfg.make_constant(element, typ.clone())); +fn make_constant_array( + dfg: &mut DataFlowGraph, + results: impl Iterator, + typ: Type, + block: BasicBlockId, + call_stack: &CallStack, +) -> ValueId { + let result_constants: im::Vector<_> = + results.map(|element| dfg.make_constant(element, typ.clone())).collect(); let typ = Type::Array(Arc::new(vec![typ]), result_constants.len()); - dfg.make_array(result_constants.into(), typ) + make_array(dfg, result_constants, typ, block, call_stack) +} + +fn make_array( + dfg: &mut DataFlowGraph, + elements: im::Vector, + typ: Type, + block: BasicBlockId, + call_stack: &CallStack, +) -> ValueId { + let instruction = Instruction::MakeArray { elements, typ }; + let call_stack = call_stack.clone(); + dfg.insert_instruction_and_results(instruction, block, None, call_stack).first() } fn make_constant_slice( dfg: &mut DataFlowGraph, results: Vec, typ: Type, + block: BasicBlockId, + call_stack: &CallStack, ) -> (ValueId, ValueId) { let result_constants = vecmap(results, |element| dfg.make_constant(element, typ.clone())); let typ = Type::Slice(Arc::new(vec![typ])); let length = FieldElement::from(result_constants.len() as u128); - (dfg.make_constant(length, Type::length_type()), dfg.make_array(result_constants.into(), typ)) + let length = dfg.make_constant(length, Type::length_type()); + + let slice = make_array(dfg, result_constants.into(), typ, block, call_stack); + (length, slice) } /// Returns a slice (represented by a tuple (len, slice)) of constants corresponding to the limbs of the radix decomposition. @@ -611,6 +654,8 @@ fn constant_to_radix( radix: u32, limb_count: u32, dfg: &mut DataFlowGraph, + block: BasicBlockId, + call_stack: &CallStack, ) -> SimplifyResult { let bit_size = u32::BITS - (radix - 1).leading_zeros(); let radix_big = BigUint::from(radix); @@ -631,7 +676,13 @@ fn constant_to_radix( if endian == Endian::Big { limbs.reverse(); } - let result_array = make_constant_array(dfg, limbs, Type::unsigned(bit_size)); + let result_array = make_constant_array( + dfg, + limbs.into_iter(), + Type::unsigned(bit_size), + block, + call_stack, + ); SimplifyResult::SimplifiedTo(result_array) } } @@ -656,6 +707,8 @@ fn simplify_hash( dfg: &mut DataFlowGraph, arguments: &[ValueId], hash_function: fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, + block: BasicBlockId, + call_stack: &CallStack, ) -> SimplifyResult { match dfg.get_array_constant(arguments[0]) { Some((input, _)) if array_is_constant(dfg, &input) => { @@ -664,9 +717,10 @@ fn simplify_hash( let hash = hash_function(&input_bytes) .expect("Rust solvable black box function should not fail"); - let hash_values = vecmap(hash, |byte| FieldElement::from_be_bytes_reduce(&[byte])); + let hash_values = hash.iter().map(|byte| FieldElement::from_be_bytes_reduce(&[*byte])); - let result_array = make_constant_array(dfg, hash_values, Type::unsigned(8)); + let result_array = + make_constant_array(dfg, hash_values, Type::unsigned(8), block, call_stack); SimplifyResult::SimplifiedTo(result_array) } _ => SimplifyResult::None, @@ -725,6 +779,8 @@ fn simplify_derive_generators( dfg: &mut DataFlowGraph, arguments: &[ValueId], num_generators: u32, + block: BasicBlockId, + call_stack: &CallStack, ) -> SimplifyResult { if arguments.len() == 2 { let domain_separator_string = dfg.get_array_constant(arguments[0]); @@ -754,8 +810,8 @@ fn simplify_derive_generators( results.push(is_infinite); } let len = results.len(); - let result = - dfg.make_array(results.into(), Type::Array(vec![Type::field()].into(), len)); + let typ = Type::Array(vec![Type::field()].into(), len); + let result = make_array(dfg, results.into(), typ, block, call_stack); SimplifyResult::SimplifiedTo(result) } else { SimplifyResult::None diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs index 3881646d5e4..4f2a31e2fb0 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs @@ -1,8 +1,13 @@ +use std::sync::Arc; + use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement}; -use iter_extended::vecmap; use crate::ssa::ir::{ - dfg::DataFlowGraph, instruction::SimplifyResult, types::Type, value::ValueId, + basic_block::BasicBlockId, + dfg::{CallStack, DataFlowGraph}, + instruction::{Instruction, SimplifyResult}, + types::Type, + value::ValueId, }; use super::{array_is_constant, make_constant_array, to_u8_vec}; @@ -11,6 +16,8 @@ pub(super) fn simplify_ec_add( dfg: &mut DataFlowGraph, solver: impl BlackBoxFunctionSolver, arguments: &[ValueId], + block: BasicBlockId, + call_stack: &CallStack, ) -> SimplifyResult { match ( dfg.get_numeric_constant(arguments[0]), @@ -39,13 +46,76 @@ pub(super) fn simplify_ec_add( return SimplifyResult::None; }; - let result_array = make_constant_array( - dfg, - vec![result_x, result_y, result_is_infinity], - Type::field(), - ); + let result_x = dfg.make_constant(result_x, Type::field()); + let result_y = dfg.make_constant(result_y, Type::field()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); - SimplifyResult::SimplifiedTo(result_array) + let typ = Type::Array(Arc::new(vec![Type::field()]), 3); + + let elements = im::vector![result_x, result_y, result_is_infinity]; + let instruction = Instruction::MakeArray { elements, typ }; + let result_array = + dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone()); + + SimplifyResult::SimplifiedTo(result_array.first()) + } + _ => SimplifyResult::None, + } +} + +pub(super) fn simplify_msm( + dfg: &mut DataFlowGraph, + solver: impl BlackBoxFunctionSolver, + arguments: &[ValueId], + block: BasicBlockId, + call_stack: &CallStack, +) -> SimplifyResult { + // TODO: Handle MSMs where a subset of the terms are constant. + match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) { + (Some((points, _)), Some((scalars, _))) => { + let Some(points) = points + .into_iter() + .map(|id| dfg.get_numeric_constant(id)) + .collect::>>() + else { + return SimplifyResult::None; + }; + + let Some(scalars) = scalars + .into_iter() + .map(|id| dfg.get_numeric_constant(id)) + .collect::>>() + else { + return SimplifyResult::None; + }; + + let mut scalars_lo = Vec::new(); + let mut scalars_hi = Vec::new(); + for (i, scalar) in scalars.into_iter().enumerate() { + if i % 2 == 0 { + scalars_lo.push(scalar); + } else { + scalars_hi.push(scalar); + } + } + + let Ok((result_x, result_y, result_is_infinity)) = + solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi) + else { + return SimplifyResult::None; + }; + + let result_x = dfg.make_constant(result_x, Type::field()); + let result_y = dfg.make_constant(result_y, Type::field()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); + + let elements = im::vector![result_x, result_y, result_is_infinity]; + let typ = Type::Array(Arc::new(vec![Type::field()]), 3); + let instruction = Instruction::MakeArray { elements, typ }; + let result_array = + dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone()); + + SimplifyResult::SimplifiedTo(result_array.first()) } _ => SimplifyResult::None, } @@ -55,6 +125,8 @@ pub(super) fn simplify_poseidon2_permutation( dfg: &mut DataFlowGraph, solver: impl BlackBoxFunctionSolver, arguments: &[ValueId], + block: BasicBlockId, + call_stack: &CallStack, ) -> SimplifyResult { match (dfg.get_array_constant(arguments[0]), dfg.get_numeric_constant(arguments[1])) { (Some((state, _)), Some(state_length)) if array_is_constant(dfg, &state) => { @@ -74,7 +146,9 @@ pub(super) fn simplify_poseidon2_permutation( return SimplifyResult::None; }; - let result_array = make_constant_array(dfg, new_state, Type::field()); + let new_state = new_state.into_iter(); + let typ = Type::field(); + let result_array = make_constant_array(dfg, new_state, typ, block, call_stack); SimplifyResult::SimplifiedTo(result_array) } @@ -119,6 +193,8 @@ pub(super) fn simplify_hash( dfg: &mut DataFlowGraph, arguments: &[ValueId], hash_function: fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, + block: BasicBlockId, + call_stack: &CallStack, ) -> SimplifyResult { match dfg.get_array_constant(arguments[0]) { Some((input, _)) if array_is_constant(dfg, &input) => { @@ -127,9 +203,10 @@ pub(super) fn simplify_hash( let hash = hash_function(&input_bytes) .expect("Rust solvable black box function should not fail"); - let hash_values = vecmap(hash, |byte| FieldElement::from_be_bytes_reduce(&[byte])); + let hash_values = hash.iter().map(|byte| FieldElement::from_be_bytes_reduce(&[*byte])); - let result_array = make_constant_array(dfg, hash_values, Type::unsigned(8)); + let u8_type = Type::unsigned(8); + let result_array = make_constant_array(dfg, hash_values, u8_type, block, call_stack); SimplifyResult::SimplifiedTo(result_array) } _ => SimplifyResult::None, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/mod.rs similarity index 100% rename from noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir.rs rename to noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/mod.rs diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/post_order.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/post_order.rs index 94ff96ba1d7..398ce887b96 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/post_order.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/post_order.rs @@ -22,7 +22,7 @@ impl PostOrder { } impl PostOrder { - /// Allocate and compute a function's block post-order. Pos + /// Allocate and compute a function's block post-order. pub(crate) fn with_function(func: &Function) -> Self { PostOrder(Self::compute_post_order(func)) } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/printer.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/printer.rs index 3bbe14f866a..c44e7d8a388 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/printer.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/printer.rs @@ -69,17 +69,6 @@ fn value(function: &Function, id: ValueId) -> String { } Value::Function(id) => id.to_string(), Value::Intrinsic(intrinsic) => intrinsic.to_string(), - Value::Array { array, typ } => { - let elements = vecmap(array, |element| value(function, *element)); - let element_types = &typ.clone().element_types(); - let element_types_str = - element_types.iter().map(|typ| typ.to_string()).collect::>().join(", "); - if element_types.len() == 1 { - format!("[{}] of {}", elements.join(", "), element_types_str) - } else { - format!("[{}] of ({})", elements.join(", "), element_types_str) - } - } Value::Param { .. } | Value::Instruction { .. } | Value::ForeignFunction(_) => { id.to_string() } @@ -230,6 +219,18 @@ fn display_instruction_inner( "if {then_condition} then {then_value} else if {else_condition} then {else_value}" ) } + Instruction::MakeArray { elements, typ } => { + write!(f, "make_array [")?; + + for (i, element) in elements.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", show(*element))?; + } + + writeln!(f, "] : {typ}") + } } } @@ -253,13 +254,9 @@ pub(crate) fn try_to_extract_string_from_error_payload( (is_string_type && (values.len() == 1)) .then_some(()) .and_then(|()| { - let Value::Array { array: values, .. } = &dfg[values[0]] else { - return None; - }; - let fields: Option> = - values.iter().map(|value_id| dfg.get_numeric_constant(*value_id)).collect(); - - fields + let (values, _) = &dfg.get_array_constant(values[0])?; + let values = values.iter().map(|value_id| dfg.get_numeric_constant(*value_id)); + values.collect::>>() }) .map(|fields| { fields diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/value.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/value.rs index 795d45c75e9..ef494200308 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/value.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/value.rs @@ -36,9 +36,6 @@ pub(crate) enum Value { /// This Value originates from a numeric constant NumericConstant { constant: FieldElement, typ: Type }, - /// Represents a constant array value - Array { array: im::Vector, typ: Type }, - /// This Value refers to a function in the IR. /// Functions always have the type Type::Function. /// If the argument or return types are needed, users should retrieve @@ -63,7 +60,6 @@ impl Value { Value::Instruction { typ, .. } => typ, Value::Param { typ, .. } => typ, Value::NumericConstant { typ, .. } => typ, - Value::Array { typ, .. } => typ, Value::Function { .. } => &Type::Function, Value::Intrinsic { .. } => &Type::Function, Value::ForeignFunction { .. } => &Type::Function, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/array_set.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/array_set.rs index 865a1e31eb3..96de22600a4 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/array_set.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/array_set.rs @@ -199,29 +199,31 @@ mod tests { let src = " brillig(inline) fn main f0 { b0(): - v1 = allocate -> &mut [Field; 5] - store [[Field 0, Field 0, Field 0, Field 0, Field 0] of Field, [Field 0, Field 0, Field 0, Field 0, Field 0] of Field] of [Field; 5] at v1 - v6 = allocate -> &mut [Field; 5] - store [[Field 0, Field 0, Field 0, Field 0, Field 0] of Field, [Field 0, Field 0, Field 0, Field 0, Field 0] of Field] of [Field; 5] at v6 + v2 = make_array [Field 0, Field 0, Field 0, Field 0, Field 0] : [Field; 5] + v3 = make_array [v2, v2] : [[Field; 5]; 2] + v4 = allocate -> &mut [Field; 5] + store v3 at v4 + v5 = allocate -> &mut [Field; 5] + store v3 at v5 jmp b1(u32 0) b1(v0: u32): - v12 = lt v0, u32 5 - jmpif v12 then: b3, else: b2 + v8 = lt v0, u32 5 + jmpif v8 then: b3, else: b2 b3(): - v13 = eq v0, u32 5 - jmpif v13 then: b4, else: b5 + v9 = eq v0, u32 5 + jmpif v9 then: b4, else: b5 b4(): - v14 = load v1 -> [[Field; 5]; 2] - store v14 at v6 + v10 = load v4 -> [[Field; 5]; 2] + store v10 at v5 jmp b5() b5(): - v15 = load v1 -> [[Field; 5]; 2] - v16 = array_get v15, index Field 0 -> [Field; 5] - v18 = array_set v16, index v0, value Field 20 - v19 = array_set v15, index v0, value v18 - store v19 at v1 - v21 = add v0, u32 1 - jmp b1(v21) + v11 = load v4 -> [[Field; 5]; 2] + v12 = array_get v11, index Field 0 -> [Field; 5] + v14 = array_set v12, index v0, value Field 20 + v15 = array_set v11, index v0, value v14 + store v15 at v4 + v17 = add v0, u32 1 + jmp b1(v17) b2(): return } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index b5ab688cb2c..9f55e69868c 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -440,7 +440,8 @@ mod test { acir(inline) fn main f0 { b0(v0: Field): v2 = add v0, Field 1 - return [v2] of Field + v3 = make_array [v2] : [Field; 1] + return v3 } "; let ssa = Ssa::from_str(src).unwrap(); @@ -551,16 +552,17 @@ mod test { // the other is not. If one is removed, it is possible e.g. v4 is replaced with v2 which // is disabled (only gets from index 0) and thus returns the wrong result. let src = " - acir(inline) fn main f0 { - b0(v0: u1, v1: u64): - enable_side_effects v0 - v5 = array_get [Field 0, Field 1] of Field, index v1 -> Field - v6 = not v0 - enable_side_effects v6 - v8 = array_get [Field 0, Field 1] of Field, index v1 -> Field - return - } - "; + acir(inline) fn main f0 { + b0(v0: u1, v1: u64): + enable_side_effects v0 + v4 = make_array [Field 0, Field 1] : [Field; 2] + v5 = array_get v4, index v1 -> Field + v6 = not v0 + enable_side_effects v6 + v7 = array_get v4, index v1 -> Field + return + } + "; let ssa = Ssa::from_str(src).unwrap(); // Expected output is unchanged @@ -638,12 +640,12 @@ mod test { let zero = builder.numeric_constant(0u128, Type::unsigned(64)); let typ = Type::Array(Arc::new(vec![Type::unsigned(64)]), 25); - let array_contents = vec![ + let array_contents = im::vector![ v0, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, zero, ]; - let array1 = builder.array_constant(array_contents.clone().into(), typ.clone()); - let array2 = builder.array_constant(array_contents.into(), typ.clone()); + let array1 = builder.insert_make_array(array_contents.clone(), typ.clone()); + let array2 = builder.insert_make_array(array_contents, typ.clone()); assert_eq!(array1, array2, "arrays were assigned different value ids"); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/die.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/die.rs index f28b076a5f9..666a8e32246 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -192,29 +192,11 @@ impl Context { }); } - /// Inspects a value recursively (as it could be an array) and marks all comprised instruction - /// results as used. + /// Inspects a value and marks all instruction results as used. fn mark_used_instruction_results(&mut self, dfg: &DataFlowGraph, value_id: ValueId) { let value_id = dfg.resolve(value_id); - match &dfg[value_id] { - Value::Instruction { .. } => { - self.used_values.insert(value_id); - } - Value::Array { array, .. } => { - self.used_values.insert(value_id); - for elem in array { - self.mark_used_instruction_results(dfg, *elem); - } - } - Value::Param { .. } => { - self.used_values.insert(value_id); - } - Value::NumericConstant { .. } => { - self.used_values.insert(value_id); - } - _ => { - // Does not comprise of any instruction results - } + if matches!(&dfg[value_id], Value::Instruction { .. } | Value::Param { .. }) { + self.used_values.insert(value_id); } } @@ -740,10 +722,11 @@ mod test { fn keep_inc_rc_on_borrowed_array_store() { // acir(inline) fn main f0 { // b0(): + // v1 = make_array [u32 0, u32 0] // v2 = allocate - // inc_rc [u32 0, u32 0] - // store [u32 0, u32 0] at v2 - // inc_rc [u32 0, u32 0] + // inc_rc v1 + // store v1 at v2 + // inc_rc v1 // jmp b1() // b1(): // v3 = load v2 @@ -756,11 +739,11 @@ mod test { let mut builder = FunctionBuilder::new("main".into(), main_id); let zero = builder.numeric_constant(0u128, Type::unsigned(32)); let array_type = Type::Array(Arc::new(vec![Type::unsigned(32)]), 2); - let array = builder.array_constant(vector![zero, zero], array_type.clone()); + let v1 = builder.insert_make_array(vector![zero, zero], array_type.clone()); let v2 = builder.insert_allocate(array_type.clone()); - builder.increment_array_reference_count(array); - builder.insert_store(v2, array); - builder.increment_array_reference_count(array); + builder.increment_array_reference_count(v1); + builder.insert_store(v2, v1); + builder.increment_array_reference_count(v1); let b1 = builder.insert_block(); builder.terminate_with_jmp(b1, vec![]); @@ -775,14 +758,14 @@ mod test { let main = ssa.main(); // The instruction count never includes the terminator instruction - assert_eq!(main.dfg[main.entry_block()].instructions().len(), 4); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 5); assert_eq!(main.dfg[b1].instructions().len(), 2); // We expect the output to be unchanged let ssa = ssa.dead_instruction_elimination(); let main = ssa.main(); - assert_eq!(main.dfg[main.entry_block()].instructions().len(), 4); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 5); assert_eq!(main.dfg[b1].instructions().len(), 2); } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index db2d96aac81..a2b8e20d20f 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -826,8 +826,8 @@ impl<'f> Context<'f> { } Value::Intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul)) => { let points_array_idx = if matches!( - self.inserter.function.dfg[arguments[0]], - Value::Array { .. } + self.inserter.function.dfg.type_of_value(arguments[0]), + Type::Array { .. } ) { 0 } else { @@ -835,15 +835,15 @@ impl<'f> Context<'f> { // which means the array is the second argument 1 }; - let (array_with_predicate, array_typ) = self - .apply_predicate_to_msm_argument( - arguments[points_array_idx], - condition, - call_stack.clone(), - ); - - arguments[points_array_idx] = - self.inserter.function.dfg.make_array(array_with_predicate, array_typ); + let (elements, typ) = self.apply_predicate_to_msm_argument( + arguments[points_array_idx], + condition, + call_stack.clone(), + ); + + let instruction = Instruction::MakeArray { elements, typ }; + let array = self.insert_instruction(instruction, call_stack); + arguments[points_array_idx] = array; Instruction::Call { func, arguments } } _ => Instruction::Call { func, arguments }, @@ -866,7 +866,7 @@ impl<'f> Context<'f> { ) -> (im::Vector, Type) { let array_typ; let mut array_with_predicate = im::Vector::new(); - if let Value::Array { array, typ } = &self.inserter.function.dfg[argument] { + if let Some((array, typ)) = &self.inserter.function.dfg.get_array_constant(argument) { array_typ = typ.clone(); for (i, value) in array.clone().iter().enumerate() { if i % 3 == 2 { @@ -1259,56 +1259,41 @@ mod test { // }; // } // - // // Translates to the following before the flattening pass: - // fn main f2 { - // b0(v0: u1): - // jmpif v0 then: b1, else: b2 - // b1(): - // v2 = allocate - // store Field 0 at v2 - // v4 = load v2 - // jmp b2() - // b2(): - // return - // } + // Translates to the following before the flattening pass: + let src = " + acir(inline) fn main f0 { + b0(v0: u1): + jmpif v0 then: b1, else: b2 + b1(): + v1 = allocate -> &mut Field + store Field 0 at v1 + v3 = load v1 -> Field + jmp b2() + b2(): + return + }"; // The bug is that the flattening pass previously inserted a load // before the first store to allocate, which loaded an uninitialized value. // In this test we assert the ordering is strictly Allocate then Store then Load. - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let b1 = builder.insert_block(); - let b2 = builder.insert_block(); - - let v0 = builder.add_parameter(Type::bool()); - builder.terminate_with_jmpif(v0, b1, b2); - - builder.switch_to_block(b1); - let v2 = builder.insert_allocate(Type::field()); - let zero = builder.field_constant(0u128); - builder.insert_store(v2, zero); - let _v4 = builder.insert_load(v2, Type::field()); - builder.terminate_with_jmp(b2, vec![]); - - builder.switch_to_block(b2); - builder.terminate_with_return(vec![]); - - let ssa = builder.finish().flatten_cfg(); - let main = ssa.main(); + let ssa = Ssa::from_str(src).unwrap(); + let flattened_ssa = ssa.flatten_cfg(); // Now assert that there is not a load between the allocate and its first store // The Expected IR is: - // - // fn main f2 { - // b0(v0: u1): - // enable_side_effects v0 - // v6 = allocate - // store Field 0 at v6 - // v7 = load v6 - // v8 = not v0 - // enable_side_effects u1 1 - // return - // } + let expected = " + acir(inline) fn main f0 { + b0(v0: u1): + enable_side_effects v0 + v1 = allocate -> &mut Field + store Field 0 at v1 + v3 = load v1 -> Field + v4 = not v0 + enable_side_effects u1 1 + return + } + "; + + let main = flattened_ssa.main(); let instructions = main.dfg[main.entry_block()].instructions(); let find_instruction = |predicate: fn(&Instruction) -> bool| { @@ -1321,6 +1306,8 @@ mod test { assert!(allocate_index < store_index); assert!(store_index < load_index); + + assert_normalized_ssa_equals(flattened_ssa, expected); } /// Work backwards from an instruction to find all the constant values @@ -1416,29 +1403,22 @@ mod test { // } let main_id = Id::test_new(1); let mut builder = FunctionBuilder::new("main".into(), main_id); - builder.insert_block(); // b0 let b1 = builder.insert_block(); let b2 = builder.insert_block(); let b3 = builder.insert_block(); - let element_type = Arc::new(vec![Type::unsigned(8)]); let array_type = Type::Array(element_type.clone(), 2); let array = builder.add_parameter(array_type); - let zero = builder.numeric_constant(0_u128, Type::unsigned(8)); - let v5 = builder.insert_array_get(array, zero, Type::unsigned(8)); let v6 = builder.insert_cast(v5, Type::unsigned(32)); let i_two = builder.numeric_constant(2_u128, Type::unsigned(32)); let v8 = builder.insert_binary(v6, BinaryOp::Mod, i_two); let v9 = builder.insert_cast(v8, Type::bool()); - let v10 = builder.insert_allocate(Type::field()); builder.insert_store(v10, zero); - builder.terminate_with_jmpif(v9, b1, b2); - builder.switch_to_block(b1); let one = builder.field_constant(1_u128); let v5b = builder.insert_cast(v5, Type::field()); @@ -1446,21 +1426,17 @@ mod test { let v14 = builder.insert_cast(v13, Type::unsigned(8)); builder.insert_store(v10, v14); builder.terminate_with_jmp(b3, vec![]); - builder.switch_to_block(b2); builder.insert_store(v10, zero); builder.terminate_with_jmp(b3, vec![]); - builder.switch_to_block(b3); let v_true = builder.numeric_constant(true, Type::bool()); let v12 = builder.insert_binary(v9, BinaryOp::Eq, v_true); builder.insert_constrain(v12, v_true, None); builder.terminate_with_return(vec![]); - let ssa = builder.finish(); let flattened_ssa = ssa.flatten_cfg(); let main = flattened_ssa.main(); - // Now assert that there is not an always-false constraint after flattening: let mut constrain_count = 0; for instruction in main.dfg[main.entry_block()].instructions() { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs index ef208588718..ddc8b0bfe6b 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs @@ -26,25 +26,23 @@ impl<'a> SliceCapacityTracker<'a> { ) { match instruction { Instruction::ArrayGet { array, .. } => { - let array_typ = self.dfg.type_of_value(*array); - let array_value = &self.dfg[*array]; - if matches!(array_value, Value::Array { .. }) && array_typ.contains_slice_element() - { - // Initial insertion into the slice sizes map - // Any other insertions should only occur if the value is already - // a part of the map. - self.compute_slice_capacity(*array, slice_sizes); + if let Some((_, array_type)) = self.dfg.get_array_constant(*array) { + if array_type.contains_slice_element() { + // Initial insertion into the slice sizes map + // Any other insertions should only occur if the value is already + // a part of the map. + self.compute_slice_capacity(*array, slice_sizes); + } } } Instruction::ArraySet { array, value, .. } => { - let array_typ = self.dfg.type_of_value(*array); - let array_value = &self.dfg[*array]; - if matches!(array_value, Value::Array { .. }) && array_typ.contains_slice_element() - { - // Initial insertion into the slice sizes map - // Any other insertions should only occur if the value is already - // a part of the map. - self.compute_slice_capacity(*array, slice_sizes); + if let Some((_, array_type)) = self.dfg.get_array_constant(*array) { + if array_type.contains_slice_element() { + // Initial insertion into the slice sizes map + // Any other insertions should only occur if the value is already + // a part of the map. + self.compute_slice_capacity(*array, slice_sizes); + } } let value_typ = self.dfg.type_of_value(*value); @@ -161,7 +159,7 @@ impl<'a> SliceCapacityTracker<'a> { array_id: ValueId, slice_sizes: &mut HashMap, ) { - if let Value::Array { array, typ } = &self.dfg[array_id] { + if let Some((array, typ)) = self.dfg.get_array_constant(array_id) { // Compiler sanity check assert!(!typ.is_nested_slice(), "ICE: Nested slices are not allowed and should not have reached the flattening pass of SSA"); if let Type::Slice(_) = typ { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs index 75ee57dd4fa..bee58278aa8 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs @@ -202,7 +202,9 @@ impl<'a> ValueMerger<'a> { } } - self.dfg.make_array(merged, typ) + let instruction = Instruction::MakeArray { elements: merged, typ }; + let call_stack = self.call_stack.clone(); + self.dfg.insert_instruction_and_results(instruction, self.block, None, call_stack).first() } fn merge_slice_values( @@ -276,7 +278,9 @@ impl<'a> ValueMerger<'a> { } } - self.dfg.make_array(merged, typ) + let instruction = Instruction::MakeArray { elements: merged, typ }; + let call_stack = self.call_stack.clone(); + self.dfg.insert_instruction_and_results(instruction, self.block, None, call_stack).first() } /// Construct a dummy value to be attached to the smaller of two slices being merged. @@ -296,7 +300,11 @@ impl<'a> ValueMerger<'a> { array.push_back(self.make_slice_dummy_data(typ)); } } - self.dfg.make_array(array, typ.clone()) + let instruction = Instruction::MakeArray { elements: array, typ: typ.clone() }; + let call_stack = self.call_stack.clone(); + self.dfg + .insert_instruction_and_results(instruction, self.block, None, call_stack) + .first() } Type::Slice(_) => { // TODO(#3188): Need to update flattening to use true user facing length of slices diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 2eb0f2eda0f..f91487fd73e 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -476,10 +476,6 @@ impl<'function> PerFunctionContext<'function> { Value::ForeignFunction(function) => { self.context.builder.import_foreign_function(function) } - Value::Array { array, typ } => { - let elements = array.iter().map(|value| self.translate_value(*value)).collect(); - self.context.builder.array_constant(elements, typ.clone()) - } }; self.values.insert(id, new_value); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index a052abc5e16..0690dbbf204 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -171,9 +171,7 @@ impl<'f> PerFunctionContext<'f> { let block_params = self.inserter.function.dfg.block_parameters(*block_id); per_func_block_params.extend(block_params.iter()); let terminator = self.inserter.function.dfg[*block_id].unwrap_terminator(); - terminator.for_each_value(|value| { - self.recursively_add_values(value, &mut all_terminator_values); - }); + terminator.for_each_value(|value| all_terminator_values.insert(value)); } // If we never load from an address within a function we can remove all stores to that address. @@ -268,15 +266,6 @@ impl<'f> PerFunctionContext<'f> { .collect() } - fn recursively_add_values(&self, value: ValueId, set: &mut HashSet) { - set.insert(value); - if let Some((elements, _)) = self.inserter.function.dfg.get_array_constant(value) { - for array_element in elements { - self.recursively_add_values(array_element, set); - } - } - } - /// The value of each reference at the start of the given block is the unification /// of the value of the same reference at the end of its predecessor blocks. fn find_starting_references(&mut self, block: BasicBlockId) -> Block { @@ -426,8 +415,6 @@ impl<'f> PerFunctionContext<'f> { let address = self.inserter.function.dfg.resolve(*address); let value = self.inserter.function.dfg.resolve(*value); - self.check_array_aliasing(references, value); - // If there was another store to this instruction without any (unremoved) loads or // function calls in-between, we can remove the previous store. if let Some(last_store) = references.last_stores.get(&address) { @@ -512,24 +499,22 @@ impl<'f> PerFunctionContext<'f> { } self.mark_all_unknown(arguments, references); } - _ => (), - } - } - - /// If `array` is an array constant that contains reference types, then insert each element - /// as a potential alias to the array itself. - fn check_array_aliasing(&self, references: &mut Block, array: ValueId) { - if let Some((elements, typ)) = self.inserter.function.dfg.get_array_constant(array) { - if Self::contains_references(&typ) { - // TODO: Check if type directly holds references or holds arrays that hold references - let expr = Expression::ArrayElement(Box::new(Expression::Other(array))); - references.expressions.insert(array, expr.clone()); - let aliases = references.aliases.entry(expr).or_default(); - - for element in elements { - aliases.insert(element); + Instruction::MakeArray { elements, typ } => { + // If `array` is an array constant that contains reference types, then insert each element + // as a potential alias to the array itself. + if Self::contains_references(typ) { + let array = self.inserter.function.dfg.instruction_results(instruction)[0]; + + let expr = Expression::ArrayElement(Box::new(Expression::Other(array))); + references.expressions.insert(array, expr.clone()); + let aliases = references.aliases.entry(expr).or_default(); + + for element in elements { + aliases.insert(*element); + } } } + _ => (), } } @@ -634,10 +619,11 @@ mod tests { // fn func() { // b0(): // v0 = allocate - // store [Field 1, Field 2] in v0 - // v1 = load v0 - // v2 = array_get v1, index 1 - // return v2 + // v1 = make_array [Field 1, Field 2] + // store v1 in v0 + // v2 = load v0 + // v3 = array_get v2, index 1 + // return v3 // } let func_id = Id::test_new(0); @@ -648,12 +634,12 @@ mod tests { let element_type = Arc::new(vec![Type::field()]); let array_type = Type::Array(element_type, 2); - let array = builder.array_constant(vector![one, two], array_type.clone()); + let v1 = builder.insert_make_array(vector![one, two], array_type.clone()); - builder.insert_store(v0, array); - let v1 = builder.insert_load(v0, array_type); - let v2 = builder.insert_array_get(v1, one, Type::field()); - builder.terminate_with_return(vec![v2]); + builder.insert_store(v0, v1); + let v2 = builder.insert_load(v0, array_type); + let v3 = builder.insert_array_get(v2, one, Type::field()); + builder.terminate_with_return(vec![v3]); let ssa = builder.finish().mem2reg().fold_constants(); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs index 098f62bceba..10e86c6601a 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -35,13 +35,14 @@ pub(crate) fn assert_normalized_ssa_equals(mut ssa: super::Ssa, expected: &str) panic!("`expected` argument of `assert_ssa_equals` is not valid SSA:\n{:?}", err); } - use crate::{ssa::Ssa, trim_leading_whitespace_from_lines}; + use crate::{ssa::Ssa, trim_comments_from_lines, trim_leading_whitespace_from_lines}; ssa.normalize_ids(); let ssa = ssa.to_string(); let ssa = trim_leading_whitespace_from_lines(&ssa); let expected = trim_leading_whitespace_from_lines(expected); + let expected = trim_comments_from_lines(&expected); if ssa != expected { println!("Expected:\n~~~\n{expected}\n~~~\nGot:\n~~~\n{ssa}\n~~~"); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/normalize_value_ids.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/normalize_value_ids.rs index 6914bf87c5d..a5b60fb5fcd 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/normalize_value_ids.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/normalize_value_ids.rs @@ -179,19 +179,6 @@ impl IdMaps { Value::NumericConstant { constant, typ } => { new_function.dfg.make_constant(*constant, typ.clone()) } - Value::Array { array, typ } => { - if let Some(value) = self.values.get(&old_value) { - return *value; - } - - let array = array - .iter() - .map(|value| self.map_value(new_function, old_function, *value)) - .collect(); - let new_value = new_function.dfg.make_array(array, typ.clone()); - self.values.insert(old_value, new_value); - new_value - } Value::Intrinsic(intrinsic) => new_function.dfg.import_intrinsic(*intrinsic), Value::ForeignFunction(name) => new_function.dfg.import_foreign_function(name), } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/rc.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/rc.rs index c3606ac4311..ffe4ada39b7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/rc.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/rc.rs @@ -197,7 +197,8 @@ mod test { // inc_rc v0 // inc_rc v0 // dec_rc v0 - // return [v0] + // v1 = make_array [v0] + // return v1 // } let main_id = Id::test_new(0); let mut builder = FunctionBuilder::new("foo".into(), main_id); @@ -211,8 +212,8 @@ mod test { builder.insert_dec_rc(v0); let outer_array_type = Type::Array(Arc::new(vec![inner_array_type]), 1); - let array = builder.array_constant(vec![v0].into(), outer_array_type); - builder.terminate_with_return(vec![array]); + let v1 = builder.insert_make_array(vec![v0].into(), outer_array_type); + builder.terminate_with_return(vec![v1]); let ssa = builder.finish().remove_paired_rc(); let main = ssa.main(); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs index 012f6e6b27d..0517f9ef89f 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs @@ -145,7 +145,8 @@ impl Context { | RangeCheck { .. } | IfElse { .. } | IncrementRc { .. } - | DecrementRc { .. } => false, + | DecrementRc { .. } + | MakeArray { .. } => false, EnableSideEffectsIf { .. } | ArrayGet { .. } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 267dc6a3c20..89f1b2b2d7d 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -7,16 +7,20 @@ //! b. If we have previously modified any of the blocks in the loop, //! restart from step 1 to refresh the context. //! c. If not, try to unroll the loop. If successful, remember the modified -//! blocks. If unsuccessfully either error if the abort_on_error flag is set, +//! blocks. If unsuccessful either error if the abort_on_error flag is set, //! or otherwise remember that the loop failed to unroll and leave it unmodified. //! //! Note that this pass also often creates superfluous jmp instructions in the -//! program that will need to be removed by a later simplify cfg pass. -//! Note also that unrolling is skipped for Brillig runtime and as a result -//! we remove reference count instructions because they are only used by Brillig bytecode +//! program that will need to be removed by a later simplify CFG pass. +//! +//! Note also that unrolling is skipped for Brillig runtime, unless the loops are deemed +//! sufficiently small that inlining can be done without increasing the bytecode. +//! +//! When unrolling ACIR code, we remove reference count instructions because they are +//! only used by Brillig bytecode. use std::collections::HashSet; -use acvm::acir::AcirField; +use acvm::{acir::AcirField, FieldElement}; use crate::{ errors::RuntimeError, @@ -26,9 +30,9 @@ use crate::{ cfg::ControlFlowGraph, dfg::{CallStack, DataFlowGraph}, dom::DominatorTree, - function::{Function, RuntimeType}, - function_inserter::FunctionInserter, - instruction::{Instruction, TerminatorInstruction}, + function::Function, + function_inserter::{ArrayCache, FunctionInserter}, + instruction::{Binary, BinaryOp, Instruction, InstructionId, TerminatorInstruction}, post_order::PostOrder, value::ValueId, }, @@ -42,16 +46,9 @@ impl Ssa { /// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found. #[tracing::instrument(level = "trace", skip(ssa))] pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result { - let acir_functions = ssa.functions.iter_mut().filter(|(_, func)| { - // Loop unrolling in brillig can lead to a code explosion currently. This can - // also be true for ACIR, but we have no alternative to unrolling in ACIR. - // Brillig also generally prefers smaller code rather than faster code. - !matches!(func.runtime(), RuntimeType::Brillig(_)) - }); - - for (_, function) in acir_functions { + for (_, function) in ssa.functions.iter_mut() { // Try to unroll loops first: - let mut unroll_errors = function.try_to_unroll_loops(); + let mut unroll_errors = function.try_unroll_loops(); // Keep unrolling until no more errors are found while !unroll_errors.is_empty() { @@ -66,21 +63,24 @@ impl Ssa { function.mem2reg(); // Unroll again - unroll_errors = function.try_to_unroll_loops(); + unroll_errors = function.try_unroll_loops(); // If we didn't manage to unroll any more loops, exit if unroll_errors.len() >= prev_unroll_err_count { return Err(unroll_errors.swap_remove(0)); } } } - Ok(ssa) } } impl Function { - fn try_to_unroll_loops(&mut self) -> Vec { - find_all_loops(self).unroll_each_loop(self) + // Loop unrolling in brillig can lead to a code explosion currently. + // This can also be true for ACIR, but we have no alternative to unrolling in ACIR. + // Brillig also generally prefers smaller code rather than faster code, + // so we only attempt to unroll small loops, which we decide on a case-by-case basis. + fn try_unroll_loops(&mut self) -> Vec { + Loops::find_all(self).unroll_each(self) } } @@ -94,7 +94,7 @@ struct Loop { back_edge_start: BasicBlockId, /// All the blocks contained within the loop, including `header` and `back_edge_start`. - pub(crate) blocks: HashSet, + blocks: HashSet, } struct Loops { @@ -107,60 +107,88 @@ struct Loops { cfg: ControlFlowGraph, } -/// Find a loop in the program by finding a node that dominates any predecessor node. -/// The edge where this happens will be the back-edge of the loop. -fn find_all_loops(function: &Function) -> Loops { - let cfg = ControlFlowGraph::with_function(function); - let post_order = PostOrder::with_function(function); - let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order); - - let mut loops = vec![]; - - for (block, _) in function.dfg.basic_blocks_iter() { - // These reachable checks wouldn't be needed if we only iterated over reachable blocks - if dom_tree.is_reachable(block) { - for predecessor in cfg.predecessors(block) { - if dom_tree.is_reachable(predecessor) && dom_tree.dominates(block, predecessor) { - // predecessor -> block is the back-edge of a loop - loops.push(find_blocks_in_loop(block, predecessor, &cfg)); +impl Loops { + /// Find a loop in the program by finding a node that dominates any predecessor node. + /// The edge where this happens will be the back-edge of the loop. + /// + /// For example consider the following SSA of a basic loop: + /// ```text + /// main(): + /// v0 = ... start ... + /// v1 = ... end ... + /// jmp loop_entry(v0) + /// loop_entry(i: Field): + /// v2 = lt i v1 + /// jmpif v2, then: loop_body, else: loop_end + /// loop_body(): + /// v3 = ... body ... + /// v4 = add 1, i + /// jmp loop_entry(v4) + /// loop_end(): + /// ``` + /// + /// The CFG will look something like this: + /// ```text + /// main + /// ↓ + /// loop_entry ←---↰ + /// ↓ ↘ | + /// loop_end loop_body + /// ``` + /// `loop_entry` has two predecessors: `main` and `loop_body`, and it dominates `loop_body`. + fn find_all(function: &Function) -> Self { + let cfg = ControlFlowGraph::with_function(function); + let post_order = PostOrder::with_function(function); + let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order); + + let mut loops = vec![]; + + for (block, _) in function.dfg.basic_blocks_iter() { + // These reachable checks wouldn't be needed if we only iterated over reachable blocks + if dom_tree.is_reachable(block) { + for predecessor in cfg.predecessors(block) { + // In the above example, we're looking for when `block` is `loop_entry` and `predecessor` is `loop_body`. + if dom_tree.is_reachable(predecessor) && dom_tree.dominates(block, predecessor) + { + // predecessor -> block is the back-edge of a loop + loops.push(Loop::find_blocks_in_loop(block, predecessor, &cfg)); + } } } } - } - // Sort loops by block size so that we unroll the larger, outer loops of nested loops first. - // This is needed because inner loops may use the induction variable from their outer loops in - // their loop range. - loops.sort_by_key(|loop_| loop_.blocks.len()); + // Sort loops by block size so that we unroll the larger, outer loops of nested loops first. + // This is needed because inner loops may use the induction variable from their outer loops in + // their loop range. We will start popping loops from the back. + loops.sort_by_key(|loop_| loop_.blocks.len()); - Loops { - failed_to_unroll: HashSet::new(), - yet_to_unroll: loops, - modified_blocks: HashSet::new(), - cfg, + Self { + failed_to_unroll: HashSet::new(), + yet_to_unroll: loops, + modified_blocks: HashSet::new(), + cfg, + } } -} -impl Loops { /// Unroll all loops within a given function. /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified. - fn unroll_each_loop(mut self, function: &mut Function) -> Vec { + fn unroll_each(mut self, function: &mut Function) -> Vec { let mut unroll_errors = vec![]; while let Some(next_loop) = self.yet_to_unroll.pop() { + if function.runtime().is_brillig() && !next_loop.is_small_loop(function, &self.cfg) { + continue; + } // If we've previously modified a block in this loop we need to refresh the context. // This happens any time we have nested loops. if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) { - let mut new_context = find_all_loops(function); - new_context.failed_to_unroll = self.failed_to_unroll; - return unroll_errors - .into_iter() - .chain(new_context.unroll_each_loop(function)) - .collect(); + let mut new_loops = Self::find_all(function); + new_loops.failed_to_unroll = self.failed_to_unroll; + return unroll_errors.into_iter().chain(new_loops.unroll_each(function)).collect(); } // Don't try to unroll the loop again if it is known to fail if !self.failed_to_unroll.contains(&next_loop.header) { - match unroll_loop(function, &self.cfg, &next_loop) { + match next_loop.unroll(function, &self.cfg) { Ok(_) => self.modified_blocks.extend(next_loop.blocks), Err(call_stack) => { self.failed_to_unroll.insert(next_loop.header); @@ -173,73 +201,522 @@ impl Loops { } } -/// Return each block that is in a loop starting in the given header block. -/// Expects back_edge_start -> header to be the back edge of the loop. -fn find_blocks_in_loop( - header: BasicBlockId, - back_edge_start: BasicBlockId, - cfg: &ControlFlowGraph, -) -> Loop { - let mut blocks = HashSet::new(); - blocks.insert(header); - - let mut insert = |block, stack: &mut Vec| { - if !blocks.contains(&block) { - blocks.insert(block); - stack.push(block); +impl Loop { + /// Return each block that is in a loop starting in the given header block. + /// Expects back_edge_start -> header to be the back edge of the loop. + fn find_blocks_in_loop( + header: BasicBlockId, + back_edge_start: BasicBlockId, + cfg: &ControlFlowGraph, + ) -> Self { + let mut blocks = HashSet::new(); + blocks.insert(header); + + let mut insert = |block, stack: &mut Vec| { + if !blocks.contains(&block) { + blocks.insert(block); + stack.push(block); + } + }; + + // Starting from the back edge of the loop, each predecessor of this block until + // the header is within the loop. + let mut stack = vec![]; + insert(back_edge_start, &mut stack); + + while let Some(block) = stack.pop() { + for predecessor in cfg.predecessors(block) { + insert(predecessor, &mut stack); + } } - }; - // Starting from the back edge of the loop, each predecessor of this block until - // the header is within the loop. - let mut stack = vec![]; - insert(back_edge_start, &mut stack); + Self { header, back_edge_start, blocks } + } + + /// Find the lower bound of the loop in the pre-header and return it + /// if it's a numeric constant, which it will be if the previous SSA + /// steps managed to inline it. + /// + /// Consider the following example of a `for i in 0..4` loop: + /// ```text + /// brillig(inline) fn main f0 { + /// b0(v0: u32): // Pre-header + /// ... + /// jmp b1(u32 0) // Lower-bound + /// b1(v1: u32): // Induction variable + /// v5 = lt v1, u32 4 + /// jmpif v5 then: b3, else: b2 + /// ``` + fn get_const_lower_bound( + &self, + function: &Function, + cfg: &ControlFlowGraph, + ) -> Result, CallStack> { + let pre_header = self.get_pre_header(function, cfg)?; + let jump_value = get_induction_variable(function, pre_header)?; + Ok(function.dfg.get_numeric_constant(jump_value)) + } - while let Some(block) = stack.pop() { - for predecessor in cfg.predecessors(block) { - insert(predecessor, &mut stack); + /// Find the upper bound of the loop in the loop header and return it + /// if it's a numeric constant, which it will be if the previous SSA + /// steps managed to inline it. + /// + /// Consider the following example of a `for i in 0..4` loop: + /// ```text + /// brillig(inline) fn main f0 { + /// b0(v0: u32): + /// ... + /// jmp b1(u32 0) + /// b1(v1: u32): // Loop header + /// v5 = lt v1, u32 4 // Upper bound + /// jmpif v5 then: b3, else: b2 + /// ``` + fn get_const_upper_bound(&self, function: &Function) -> Option { + let block = &function.dfg[self.header]; + let instructions = block.instructions(); + assert_eq!( + instructions.len(), + 1, + "The header should just compare the induction variable and jump" + ); + match &function.dfg[instructions[0]] { + Instruction::Binary(Binary { lhs: _, operator: BinaryOp::Lt, rhs }) => { + function.dfg.get_numeric_constant(*rhs) + } + Instruction::Binary(Binary { lhs: _, operator: BinaryOp::Eq, rhs }) => { + // `for i in 0..1` is turned into: + // b1(v0: u32): + // v12 = eq v0, u32 0 + // jmpif v12 then: b3, else: b2 + function.dfg.get_numeric_constant(*rhs).map(|c| c + FieldElement::one()) + } + other => panic!("Unexpected instruction in header: {other:?}"), } } - Loop { header, back_edge_start, blocks } -} + /// Get the lower and upper bounds of the loop if both are constant numeric values. + fn get_const_bounds( + &self, + function: &Function, + cfg: &ControlFlowGraph, + ) -> Result, CallStack> { + let Some(lower) = self.get_const_lower_bound(function, cfg)? else { + return Ok(None); + }; + let Some(upper) = self.get_const_upper_bound(function) else { + return Ok(None); + }; + Ok(Some((lower, upper))) + } + + /// Unroll a single loop in the function. + /// Returns Ok(()) if it succeeded, Err(callstack) if it failed, + /// where the callstack indicates the location of the instruction + /// that could not be processed, or empty if such information was + /// not available. + /// + /// Consider this example: + /// ```text + /// main(): + /// v0 = 0 + /// v1 = 2 + /// jmp loop_entry(v0) + /// loop_entry(i: Field): + /// v2 = lt i v1 + /// jmpif v2, then: loop_body, else: loop_end + /// ``` + /// + /// The first step is to unroll the header by recognizing that jump condition + /// is a constant, which means it will go to `loop_body`: + /// ```text + /// main(): + /// v0 = 0 + /// v1 = 2 + /// v2 = lt v0 v1 + /// // jmpif v2, then: loop_body, else: loop_end + /// jmp dest: loop_body + /// ``` + /// + /// Following that we unroll the loop body, which is the next source, replace + /// the induction variable with the new value created in the body, and have + /// another go at the header. + /// ```text + /// main(): + /// v0 = 0 + /// v1 = 2 + /// v2 = lt v0 v1 + /// v3 = ... body ... + /// v4 = add 1, 0 + /// jmp loop_entry(v4) + /// ``` + /// + /// At the end we reach a point where the condition evaluates to 0 and we jump to the end. + /// ```text + /// main(): + /// v0 = 0 + /// v1 = 2 + /// v2 = lt 0 + /// v3 = ... body ... + /// v4 = add 1, v0 + /// v5 = lt v4 v1 + /// v6 = ... body ... + /// v7 = add v4, 1 + /// v8 = lt v5 v1 + /// jmp loop_end + /// ``` + /// + /// When e.g. `v8 = lt v5 v1` cannot be evaluated to a constant, the loop signals by returning `Err` + /// that a few SSA passes are required to evaluate and simplify these values. + fn unroll(&self, function: &mut Function, cfg: &ControlFlowGraph) -> Result<(), CallStack> { + let mut unroll_into = self.get_pre_header(function, cfg)?; + let mut jump_value = get_induction_variable(function, unroll_into)?; + let mut array_cache = Some(ArrayCache::default()); + + while let Some(mut context) = self.unroll_header(function, unroll_into, jump_value)? { + // The inserter's array cache must be explicitly enabled. This is to + // confirm that we're inserting in insertion order. This is true here since: + // 1. We have a fresh inserter for each loop + // 2. Each loop is unrolled in iteration order + // + // Within a loop we do not insert in insertion order. This is fine however since the + // array cache is buffered with a separate fresh_array_cache which collects arrays + // but does not deduplicate. When we later call `into_array_cache`, that will merge + // the fresh cache in with the old one so that each iteration of the loop can cache + // from previous iterations but not the current iteration. + context.inserter.set_array_cache(array_cache, unroll_into); + (unroll_into, jump_value, array_cache) = context.unroll_loop_iteration(); + } + + Ok(()) + } + + /// The loop pre-header is the block that comes before the loop begins. Generally a header block + /// is expected to have 2 predecessors: the pre-header and the final block of the loop which jumps + /// back to the beginning. Other predecessors can come from `break` or `continue`. + fn get_pre_header( + &self, + function: &Function, + cfg: &ControlFlowGraph, + ) -> Result { + let mut pre_header = cfg + .predecessors(self.header) + .filter(|predecessor| *predecessor != self.back_edge_start) + .collect::>(); + + if function.runtime().is_acir() { + assert_eq!(pre_header.len(), 1); + Ok(pre_header.remove(0)) + } else if pre_header.len() == 1 { + Ok(pre_header.remove(0)) + } else { + // We can come back into the header from multiple blocks, so we can't unroll this. + Err(CallStack::new()) + } + } + + /// Unrolls the header block of the loop. This is the block that dominates all other blocks in the + /// loop and contains the jmpif instruction that lets us know if we should continue looping. + /// Returns Some(iteration context) if we should perform another iteration. + fn unroll_header<'a>( + &'a self, + function: &'a mut Function, + unroll_into: BasicBlockId, + induction_value: ValueId, + ) -> Result>, CallStack> { + // We insert into a fresh block first and move instructions into the unroll_into block later + // only once we verify the jmpif instruction has a constant condition. If it does not, we can + // just discard this fresh block and leave the loop unmodified. + let fresh_block = function.dfg.make_block(); + + let mut context = LoopIteration::new(function, self, fresh_block, self.header); + let source_block = &context.dfg()[context.source_block]; + assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header"); + + // Insert the current value of the loop induction variable into our context. + let first_param = source_block.parameters()[0]; + context.inserter.try_map_value(first_param, induction_value); + // Copy over all instructions and a fresh terminator. + context.inline_instructions_from_block(); + // Mutate the terminator if possible so that it points at the iteration block. + match context.dfg()[fresh_block].unwrap_terminator() { + TerminatorInstruction::JmpIf { condition, then_destination, else_destination, call_stack } => { + let condition = *condition; + let next_blocks = context.handle_jmpif(condition, *then_destination, *else_destination, call_stack.clone()); + + // If there is only 1 next block the jmpif evaluated to a single known block. + // This is the expected case and lets us know if we should loop again or not. + if next_blocks.len() == 1 { + context.dfg_mut().inline_block(fresh_block, unroll_into); + + // The fresh block is gone now so we're committing to insert into the original + // unroll_into block from now on. + context.insert_block = unroll_into; + + // In the last iteration, `handle_jmpif` will have replaced `context.source_block` + // with the `else_destination`, that is, the `loop_end`, which signals that we + // have no more loops to unroll, because that block was not part of the loop itself, + // ie. it wasn't between `loop_header` and `loop_body`. Otherwise we have the `loop_body` + // in `source_block` and can unroll that into the destination. + Ok(self.blocks.contains(&context.source_block).then_some(context)) + } else { + // If this case is reached the loop either uses non-constant indices or we need + // another pass, such as mem2reg to resolve them to constants. + Err(context.inserter.function.dfg.get_value_call_stack(condition)) + } + } + other => unreachable!("Expected loop header to terminate in a JmpIf to the loop body, but found {other:?} instead"), + } + } + + /// Find all reference values which were allocated before the pre-header. + /// + /// These are accessible inside the loop body, and they can be involved + /// in load/store operations that could be eliminated if we unrolled the + /// body into the pre-header. + /// + /// Consider this loop: + /// ```text + /// let mut sum = 0; + /// let mut arr = &[]; + /// for i in 0..3 { + /// sum = sum + i; + /// arr.push_back(sum) + /// } + /// sum + /// ``` + /// + /// The SSA has a load+store for the `sum` and a load+push for the `arr`: + /// ```text + /// b0(v0: u32): + /// v2 = allocate -> &mut u32 // reference allocated for `sum` + /// store u32 0 at v2 // initial value for `sum` + /// v4 = allocate -> &mut u32 // reference allocated for the length of `arr` + /// store u32 0 at v4 // initial length of `arr` + /// inc_rc [] of u32 // storage for `arr` + /// v6 = allocate -> &mut [u32] // reference allocated to point at the storage of `arr` + /// store [] of u32 at v6 // initial value for the storage of `arr` + /// jmp b1(u32 0) // start looping from 0 + /// b1(v1: u32): // `i` induction variable + /// v8 = lt v1, u32 3 // loop until 3 + /// jmpif v8 then: b3, else: b2 + /// b3(): + /// v11 = load v2 -> u32 // load `sum` + /// v12 = add v11, v1 // add `i` to `sum` + /// store v12 at v2 // store updated `sum` + /// v13 = load v4 -> u32 // load length of `arr` + /// v14 = load v6 -> [u32] // load storage of `arr` + /// v16, v17 = call slice_push_back(v13, v14, v12) -> (u32, [u32]) // builtin to push, will store to storage and length references + /// v19 = add v1, u32 1 // increase `arr` + /// jmp b1(v19) // back-edge of the loop + /// b2(): // after the loop + /// v9 = load v2 -> u32 // read final value of `sum` + /// ``` + /// + /// We won't always find load _and_ store ops (e.g. the push above doesn't come with a store), + /// but it's likely that mem2reg could eliminate a lot of the loads we can find, so we can + /// use this as an approximation of the gains we would see. + fn find_pre_header_reference_values( + &self, + function: &Function, + cfg: &ControlFlowGraph, + ) -> Result, CallStack> { + // We need to traverse blocks from the pre-header up to the block entry point. + let pre_header = self.get_pre_header(function, cfg)?; + let function_entry = function.entry_block(); + + // The algorithm in `find_blocks_in_loop` expects to collect the blocks between the header and the back-edge of the loop, + // but technically works the same if we go from the pre-header up to the function entry as well. + let blocks = Self::find_blocks_in_loop(function_entry, pre_header, cfg).blocks; + + // Collect allocations in all blocks above the header. + let allocations = blocks.iter().flat_map(|b| { + function.dfg[*b] + .instructions() + .iter() + .filter(|i| matches!(&function.dfg[**i], Instruction::Allocate)) + .map(|i| { + // Get the value into which the allocation was stored. + function.dfg.instruction_results(*i)[0] + }) + }); + + // Collect reference parameters of the function itself. + let params = + function.parameters().iter().filter(|p| function.dfg.value_is_reference(**p)).copied(); + + Ok(params.chain(allocations).collect()) + } + + /// Count the number of load and store instructions of specific variables in the loop. + /// + /// Returns `(loads, stores)` in case we want to differentiate in the estimates. + fn count_loads_and_stores( + &self, + function: &Function, + refs: &HashSet, + ) -> (usize, usize) { + let mut loads = 0; + let mut stores = 0; + for block in &self.blocks { + for instruction in function.dfg[*block].instructions() { + match &function.dfg[*instruction] { + Instruction::Load { address } if refs.contains(address) => { + loads += 1; + } + Instruction::Store { address, .. } if refs.contains(address) => { + stores += 1; + } + _ => {} + } + } + } + (loads, stores) + } + + /// Count the number of instructions in the loop, including the terminating jumps. + fn count_all_instructions(&self, function: &Function) -> usize { + self.blocks + .iter() + .map(|block| { + let block = &function.dfg[*block]; + block.instructions().len() + block.terminator().map(|_| 1).unwrap_or_default() + }) + .sum() + } -/// Unroll a single loop in the function. -/// Returns Err(()) if it failed to unroll and Ok(()) otherwise. -fn unroll_loop( - function: &mut Function, - cfg: &ControlFlowGraph, - loop_: &Loop, -) -> Result<(), CallStack> { - let mut unroll_into = get_pre_header(cfg, loop_); - let mut jump_value = get_induction_variable(function, unroll_into)?; + /// Count the number of increments to the induction variable. + /// It should be one, but it can be duplicated. + /// The increment should be in the block where the back-edge was found. + fn count_induction_increments(&self, function: &Function) -> usize { + let back = &function.dfg[self.back_edge_start]; + let header = &function.dfg[self.header]; + let induction_var = header.parameters()[0]; + + back.instructions().iter().filter(|instruction| { + let instruction = &function.dfg[**instruction]; + matches!(instruction, Instruction::Binary(Binary { lhs, operator: BinaryOp::Add, rhs: _ }) if *lhs == induction_var) + }).count() + } - while let Some(context) = unroll_loop_header(function, loop_, unroll_into, jump_value)? { - let (last_block, last_value) = context.unroll_loop_iteration(); - unroll_into = last_block; - jump_value = last_value; + /// Decide if this loop is small enough that it can be inlined in a way that the number + /// of unrolled instructions times the number of iterations would result in smaller bytecode + /// than if we keep the loops with their overheads. + fn is_small_loop(&self, function: &Function, cfg: &ControlFlowGraph) -> bool { + self.boilerplate_stats(function, cfg).map(|s| s.is_small()).unwrap_or_default() } - Ok(()) + /// Collect boilerplate stats if we can figure out the upper and lower bounds of the loop, + /// and the loop doesn't have multiple back-edges from breaks and continues. + fn boilerplate_stats( + &self, + function: &Function, + cfg: &ControlFlowGraph, + ) -> Option { + let Ok(Some((lower, upper))) = self.get_const_bounds(function, cfg) else { + return None; + }; + let Some(lower) = lower.try_to_u64() else { + return None; + }; + let Some(upper) = upper.try_to_u64() else { + return None; + }; + let Ok(refs) = self.find_pre_header_reference_values(function, cfg) else { + return None; + }; + let (loads, stores) = self.count_loads_and_stores(function, &refs); + let increments = self.count_induction_increments(function); + let all_instructions = self.count_all_instructions(function); + + Some(BoilerplateStats { + iterations: (upper - lower) as usize, + loads, + stores, + increments, + all_instructions, + }) + } } -/// The loop pre-header is the block that comes before the loop begins. Generally a header block -/// is expected to have 2 predecessors: the pre-header and the final block of the loop which jumps -/// back to the beginning. -fn get_pre_header(cfg: &ControlFlowGraph, loop_: &Loop) -> BasicBlockId { - let mut pre_header = cfg - .predecessors(loop_.header) - .filter(|predecessor| *predecessor != loop_.back_edge_start) - .collect::>(); - - assert_eq!(pre_header.len(), 1); - pre_header.remove(0) +/// All the instructions in the following example are boilerplate: +/// ```text +/// brillig(inline) fn main f0 { +/// b0(v0: u32): +/// ... +/// jmp b1(u32 0) +/// b1(v1: u32): +/// v5 = lt v1, u32 4 +/// jmpif v5 then: b3, else: b2 +/// b3(): +/// ... +/// v11 = add v1, u32 1 +/// jmp b1(v11) +/// b2(): +/// ... +/// } +/// ``` +#[derive(Debug)] +struct BoilerplateStats { + /// Number of iterations in the loop. + iterations: usize, + /// Number of loads pre-header references in the loop. + loads: usize, + /// Number of stores into pre-header references in the loop. + stores: usize, + /// Number of increments to the induction variable (might be duplicated). + increments: usize, + /// Number of instructions in the loop, including boilerplate, + /// but excluding the boilerplate which is outside the loop. + all_instructions: usize, +} + +impl BoilerplateStats { + /// Instruction count if we leave the loop as-is. + /// It's the instructions in the loop, plus the one to kick it off in the pre-header. + fn baseline_instructions(&self) -> usize { + self.all_instructions + 1 + } + + /// Estimated number of _useful_ instructions, which is the ones in the loop + /// minus all in-loop boilerplate. + fn useful_instructions(&self) -> usize { + // Two jumps + plus the comparison with the upper bound + let boilerplate = 3; + // Be conservative and only assume that mem2reg gets rid of load followed by store. + // NB we have not checked that these are actual pairs. + let load_and_store = self.loads.min(self.stores) * 2; + self.all_instructions - self.increments - load_and_store - boilerplate + } + + /// Estimated number of instructions if we unroll the loop. + fn unrolled_instructions(&self) -> usize { + self.useful_instructions() * self.iterations + } + + /// A small loop is where if we unroll it into the pre-header then considering the + /// number of iterations we still end up with a smaller bytecode than if we leave + /// the blocks in tact with all the boilerplate involved in jumping, and the extra + /// reference access instructions. + fn is_small(&self) -> bool { + self.unrolled_instructions() < self.baseline_instructions() + } } /// Return the induction value of the current iteration of the loop, from the given block's jmp arguments. /// /// Expects the current block to terminate in `jmp h(N)` where h is the loop header and N is -/// a Field value. +/// a Field value. Returns an `Err` if this isn't the case. +/// +/// Consider the following example: +/// ```text +/// main(): +/// v0 = ... start ... +/// v1 = ... end ... +/// jmp loop_entry(v0) +/// loop_entry(i: Field): +/// ... +/// ``` +/// We're looking for the terminating jump of the `main` predecessor of `loop_entry`. fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result { match function.dfg[block].terminator() { Some(TerminatorInstruction::Jmp { arguments, call_stack: location, .. }) => { @@ -260,54 +737,6 @@ fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result( - function: &'a mut Function, - loop_: &'a Loop, - unroll_into: BasicBlockId, - induction_value: ValueId, -) -> Result>, CallStack> { - // We insert into a fresh block first and move instructions into the unroll_into block later - // only once we verify the jmpif instruction has a constant condition. If it does not, we can - // just discard this fresh block and leave the loop unmodified. - let fresh_block = function.dfg.make_block(); - - let mut context = LoopIteration::new(function, loop_, fresh_block, loop_.header); - let source_block = &context.dfg()[context.source_block]; - assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header"); - - // Insert the current value of the loop induction variable into our context. - let first_param = source_block.parameters()[0]; - context.inserter.try_map_value(first_param, induction_value); - context.inline_instructions_from_block(); - - match context.dfg()[fresh_block].unwrap_terminator() { - TerminatorInstruction::JmpIf { condition, then_destination, else_destination, call_stack } => { - let condition = *condition; - let next_blocks = context.handle_jmpif(condition, *then_destination, *else_destination, call_stack.clone()); - - // If there is only 1 next block the jmpif evaluated to a single known block. - // This is the expected case and lets us know if we should loop again or not. - if next_blocks.len() == 1 { - context.dfg_mut().inline_block(fresh_block, unroll_into); - - // The fresh block is gone now so we're committing to insert into the original - // unroll_into block from now on. - context.insert_block = unroll_into; - - Ok(loop_.blocks.contains(&context.source_block).then_some(context)) - } else { - // If this case is reached the loop either uses non-constant indices or we need - // another pass, such as mem2reg to resolve them to constants. - Err(context.inserter.function.dfg.get_value_call_stack(condition)) - } - } - other => unreachable!("Expected loop header to terminate in a JmpIf to the loop body, but found {other:?} instead"), - } -} - /// The context object for each loop iteration. /// Notably each loop iteration maps each loop block to a fresh, unrolled block. struct LoopIteration<'f> { @@ -357,7 +786,7 @@ impl<'f> LoopIteration<'f> { /// It is expected the terminator instructions are set up to branch into an empty block /// for further unrolling. When the loop is finished this will need to be mutated to /// jump to the end of the loop instead. - fn unroll_loop_iteration(mut self) -> (BasicBlockId, ValueId) { + fn unroll_loop_iteration(mut self) -> (BasicBlockId, ValueId, Option) { let mut next_blocks = self.unroll_loop_block(); while let Some(block) = next_blocks.pop() { @@ -369,14 +798,20 @@ impl<'f> LoopIteration<'f> { next_blocks.append(&mut blocks); } } + // After having unrolled all blocks in the loop body, we must know how to get back to the header; + // this is also the block into which we have to unroll into next. + let (end_block, induction_value) = self + .induction_value + .expect("Expected to find the induction variable by end of loop iteration"); - self.induction_value - .expect("Expected to find the induction variable by end of loop iteration") + (end_block, induction_value, self.inserter.into_array_cache()) } /// Unroll a single block in the current iteration of the loop fn unroll_loop_block(&mut self) -> Vec { let mut next_blocks = self.unroll_loop_block_helper(); + // Guarantee that the next blocks we set up to be unrolled, are actually part of the loop, + // which we recorded while inlining the instructions of the blocks already processed. next_blocks.retain(|block| { let b = self.get_original_block(*block); self.loop_.blocks.contains(&b) @@ -386,6 +821,7 @@ impl<'f> LoopIteration<'f> { /// Unroll a single block in the current iteration of the loop fn unroll_loop_block_helper(&mut self) -> Vec { + // Copy instructions from the loop body to the unroll destination, replacing the terminator. self.inline_instructions_from_block(); self.visited_blocks.insert(self.source_block); @@ -403,6 +839,7 @@ impl<'f> LoopIteration<'f> { ), TerminatorInstruction::Jmp { destination, arguments, call_stack: _ } => { if self.get_original_block(*destination) == self.loop_.header { + // We found the back-edge of the loop. assert_eq!(arguments.len(), 1); self.induction_value = Some((self.insert_block, arguments[0])); } @@ -414,7 +851,10 @@ impl<'f> LoopIteration<'f> { /// Find the next branch(es) to take from a jmpif terminator and return them. /// If only one block is returned, it means the jmpif condition evaluated to a known - /// constant and we can safely take only the given branch. + /// constant and we can safely take only the given branch. In this case the method + /// also replaces the terminator of the insert block (a.k.a fresh block) to be a `Jmp`, + /// and changes the source block in the context for the next iteration to be the + /// destination indicated by the constant condition (ie. the `then` or the `else`). fn handle_jmpif( &mut self, condition: ValueId, @@ -460,10 +900,13 @@ impl<'f> LoopIteration<'f> { } } + /// Find the original ID of a block that replaced it. fn get_original_block(&self, block: BasicBlockId) -> BasicBlockId { self.original_blocks.get(&block).copied().unwrap_or(block) } + /// Copy over instructions from the source into the insert block, + /// while simplifying instructions and keeping track of original block IDs. fn inline_instructions_from_block(&mut self) { let source_block = &self.dfg()[self.source_block]; let instructions = source_block.instructions().to_vec(); @@ -472,23 +915,31 @@ impl<'f> LoopIteration<'f> { // instances of the induction variable or any values that were changed as a result // of the new induction variable value. for instruction in instructions { - // Skip reference count instructions since they are only used for brillig, and brillig code is not unrolled - if !matches!( - self.dfg()[instruction], - Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. } - ) { - self.inserter.push_instruction(instruction, self.insert_block); + // Reference counting is only used by Brillig, ACIR doesn't need them. + if self.inserter.function.runtime().is_acir() && self.is_refcount(instruction) { + continue; } + self.inserter.push_instruction(instruction, self.insert_block); } let mut terminator = self.dfg()[self.source_block] .unwrap_terminator() .clone() .map_values(|value| self.inserter.resolve(value)); + // Replace the blocks in the terminator with fresh one with the same parameters, + // while remembering which were the original block IDs. terminator.mutate_blocks(|block| self.get_or_insert_block(block)); self.inserter.function.dfg.set_block_terminator(self.insert_block, terminator); } + /// Is the instruction an `Rc`? + fn is_refcount(&self, instruction: InstructionId) -> bool { + matches!( + self.dfg()[instruction], + Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. } + ) + } + fn dfg(&self) -> &DataFlowGraph { &self.inserter.function.dfg } @@ -500,22 +951,19 @@ impl<'f> LoopIteration<'f> { #[cfg(test)] mod tests { - use crate::{ - errors::RuntimeError, - ssa::{ - function_builder::FunctionBuilder, - ir::{instruction::BinaryOp, map::Id, types::Type}, - }, - }; + use acvm::FieldElement; + + use crate::errors::RuntimeError; + use crate::ssa::{ir::value::ValueId, opt::assert_normalized_ssa_equals, Ssa}; - use super::Ssa; + use super::{BoilerplateStats, Loops}; /// Tries to unroll all loops in each SSA function. /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. - fn try_to_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec) { + fn try_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec) { let mut errors = vec![]; for function in ssa.functions.values_mut() { - errors.extend(function.try_to_unroll_loops()); + errors.extend(function.try_unroll_loops()); } (ssa, errors) } @@ -529,166 +977,406 @@ mod tests { // } // } // } - // - // fn main f0 { - // b0(): - // jmp b1(Field 0) - // b1(v0: Field): // header of outer loop - // v1 = lt v0, Field 3 - // jmpif v1, then: b2, else: b3 - // b2(): - // jmp b4(Field 0) - // b4(v2: Field): // header of inner loop - // v3 = lt v2, Field 4 - // jmpif v3, then: b5, else: b6 - // b5(): - // v4 = add v0, v2 - // v5 = lt Field 10, v4 - // constrain v5 - // v6 = add v2, Field 1 - // jmp b4(v6) - // b6(): // end of inner loop - // v7 = add v0, Field 1 - // jmp b1(v7) - // b3(): // end of outer loop - // return Field 0 - // } - let main_id = Id::test_new(0); - - // Compiling main - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let b1 = builder.insert_block(); - let b2 = builder.insert_block(); - let b3 = builder.insert_block(); - let b4 = builder.insert_block(); - let b5 = builder.insert_block(); - let b6 = builder.insert_block(); - - let v0 = builder.add_block_parameter(b1, Type::field()); - let v2 = builder.add_block_parameter(b4, Type::field()); - - let zero = builder.field_constant(0u128); - let one = builder.field_constant(1u128); - let three = builder.field_constant(3u128); - let four = builder.field_constant(4u128); - let ten = builder.field_constant(10u128); - - builder.terminate_with_jmp(b1, vec![zero]); - - // b1 - builder.switch_to_block(b1); - let v1 = builder.insert_binary(v0, BinaryOp::Lt, three); - builder.terminate_with_jmpif(v1, b2, b3); - - // b2 - builder.switch_to_block(b2); - builder.terminate_with_jmp(b4, vec![zero]); - - // b3 - builder.switch_to_block(b3); - builder.terminate_with_return(vec![zero]); - - // b4 - builder.switch_to_block(b4); - let v3 = builder.insert_binary(v2, BinaryOp::Lt, four); - builder.terminate_with_jmpif(v3, b5, b6); - - // b5 - builder.switch_to_block(b5); - let v4 = builder.insert_binary(v0, BinaryOp::Add, v2); - let v5 = builder.insert_binary(ten, BinaryOp::Lt, v4); - builder.insert_constrain(v5, one, None); - let v6 = builder.insert_binary(v2, BinaryOp::Add, one); - builder.terminate_with_jmp(b4, vec![v6]); - - // b6 - builder.switch_to_block(b6); - let v7 = builder.insert_binary(v0, BinaryOp::Add, one); - builder.terminate_with_jmp(b1, vec![v7]); - - let ssa = builder.finish(); - assert_eq!(ssa.main().reachable_blocks().len(), 7); - - // Expected output: - // - // fn main f0 { - // b0(): - // constrain Field 0 - // constrain Field 0 - // constrain Field 0 - // constrain Field 0 - // jmp b23() - // b23(): - // constrain Field 0 - // constrain Field 0 - // constrain Field 0 - // constrain Field 0 - // jmp b27() - // b27(): - // constrain Field 0 - // constrain Field 0 - // constrain Field 0 - // constrain Field 0 - // jmp b31() - // b31(): - // jmp b3() - // b3(): - // return Field 0 - // } + let src = " + acir(inline) fn main f0 { + b0(): + jmp b1(Field 0) + b1(v0: Field): // header of outer loop + v1 = lt v0, Field 3 + jmpif v1 then: b2, else: b3 + b2(): + jmp b4(Field 0) + b4(v2: Field): // header of inner loop + v3 = lt v2, Field 4 + jmpif v3 then: b5, else: b6 + b5(): + v4 = add v0, v2 + v5 = lt Field 10, v4 + constrain v5 == Field 1 + v6 = add v2, Field 1 + jmp b4(v6) + b6(): // end of inner loop + v7 = add v0, Field 1 + jmp b1(v7) + b3(): // end of outer loop + return Field 0 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + acir(inline) fn main f0 { + b0(): + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + jmp b1() + b1(): + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + jmp b2() + b2(): + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + constrain u1 0 == Field 1 + jmp b3() + b3(): + jmp b4() + b4(): + return Field 0 + } + "; + // The final block count is not 1 because unrolling creates some unnecessary jmps. // If a simplify cfg pass is ran afterward, the expected block count will be 1. - let (ssa, errors) = try_to_unroll_loops(ssa); + let (ssa, errors) = try_unroll_loops(ssa); assert_eq!(errors.len(), 0, "All loops should be unrolled"); assert_eq!(ssa.main().reachable_blocks().len(), 5); + + assert_normalized_ssa_equals(ssa, expected); } // Test that the pass can still be run on loops which fail to unroll properly #[test] fn fail_to_unroll_loop() { - // fn main f0 { - // b0(v0: Field): - // jmp b1(v0) - // b1(v1: Field): - // v2 = lt v1, 5 - // jmpif v2, then: b2, else: b3 - // b2(): - // v3 = add v1, Field 1 - // jmp b1(v3) - // b3(): - // return Field 0 - // } - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); + let src = " + acir(inline) fn main f0 { + b0(v0: Field): + jmp b1(v0) + b1(v1: Field): + v2 = lt v1, Field 5 + jmpif v2 then: b2, else: b3 + b2(): + v3 = add v1, Field 1 + jmp b1(v3) + b3(): + return Field 0 + } + "; + let ssa = Ssa::from_str(src).unwrap(); - let b1 = builder.insert_block(); - let b2 = builder.insert_block(); - let b3 = builder.insert_block(); + // Sanity check + assert_eq!(ssa.main().reachable_blocks().len(), 4); - let v0 = builder.add_parameter(Type::field()); - let v1 = builder.add_block_parameter(b1, Type::field()); + // Expected that we failed to unroll the loop + let (_, errors) = try_unroll_loops(ssa); + assert_eq!(errors.len(), 1, "Expected to fail to unroll loop"); + } - builder.terminate_with_jmp(b1, vec![v0]); + #[test] + fn test_get_const_bounds() { + let ssa = brillig_unroll_test_case(); + let function = ssa.main(); + let loops = Loops::find_all(function); + assert_eq!(loops.yet_to_unroll.len(), 1); + + let (lower, upper) = loops.yet_to_unroll[0] + .get_const_bounds(function, &loops.cfg) + .expect("should find bounds") + .expect("bounds are numeric const"); + + assert_eq!(lower, FieldElement::from(0u32)); + assert_eq!(upper, FieldElement::from(4u32)); + } - builder.switch_to_block(b1); - let five = builder.field_constant(5u128); - let v2 = builder.insert_binary(v1, BinaryOp::Lt, five); - builder.terminate_with_jmpif(v2, b2, b3); + #[test] + fn test_find_pre_header_reference_values() { + let ssa = brillig_unroll_test_case(); + let function = ssa.main(); + let mut loops = Loops::find_all(function); + let loop0 = loops.yet_to_unroll.pop().unwrap(); + + let refs = loop0.find_pre_header_reference_values(function, &loops.cfg).unwrap(); + assert_eq!(refs.len(), 1); + assert!(refs.contains(&ValueId::new(2))); + + let (loads, stores) = loop0.count_loads_and_stores(function, &refs); + assert_eq!(loads, 1); + assert_eq!(stores, 1); + + let all = loop0.count_all_instructions(function); + assert_eq!(all, 7); + } - builder.switch_to_block(b2); - let one = builder.field_constant(1u128); - let v3 = builder.insert_binary(v1, BinaryOp::Add, one); - builder.terminate_with_jmp(b1, vec![v3]); + #[test] + fn test_boilerplate_stats() { + let ssa = brillig_unroll_test_case(); + let stats = loop0_stats(&ssa); + assert_eq!(stats.iterations, 4); + assert_eq!(stats.all_instructions, 2 + 5); // Instructions in b1 and b3 + assert_eq!(stats.increments, 1); + assert_eq!(stats.loads, 1); + assert_eq!(stats.stores, 1); + assert_eq!(stats.useful_instructions(), 1); // Adding to sum + assert_eq!(stats.baseline_instructions(), 8); + assert!(stats.is_small()); + } - builder.switch_to_block(b3); - let zero = builder.field_constant(0u128); - builder.terminate_with_return(vec![zero]); + #[test] + fn test_boilerplate_stats_6470() { + let ssa = brillig_unroll_test_case_6470(3); + let stats = loop0_stats(&ssa); + assert_eq!(stats.iterations, 3); + assert_eq!(stats.all_instructions, 2 + 8); // Instructions in b1 and b3 + assert_eq!(stats.increments, 2); + assert_eq!(stats.loads, 1); + assert_eq!(stats.stores, 1); + assert_eq!(stats.useful_instructions(), 3); // array get, add, array set + assert_eq!(stats.baseline_instructions(), 11); + assert!(stats.is_small()); + } - let ssa = builder.finish(); - assert_eq!(ssa.main().reachable_blocks().len(), 4); + /// Test that we can unroll a small loop. + #[test] + fn test_brillig_unroll_small_loop() { + let ssa = brillig_unroll_test_case(); + + // Expectation taken by compiling the Noir program as ACIR, + // ie. by removing the `unconstrained` from `main`. + let expected = " + brillig(inline) fn main f0 { + b0(v0: u32): + v1 = allocate -> &mut u32 + store u32 0 at v1 + v3 = load v1 -> u32 + store v3 at v1 + v4 = load v1 -> u32 + v6 = add v4, u32 1 + store v6 at v1 + v7 = load v1 -> u32 + v9 = add v7, u32 2 + store v9 at v1 + v10 = load v1 -> u32 + v12 = add v10, u32 3 + store v12 at v1 + jmp b1() + b1(): + v13 = load v1 -> u32 + v14 = eq v13, v0 + constrain v13 == v0 + return + } + "; - // Expected that we failed to unroll the loop - let (_, errors) = try_to_unroll_loops(ssa); - assert_eq!(errors.len(), 1, "Expected to fail to unroll loop"); + let (ssa, errors) = try_unroll_loops(ssa); + assert_eq!(errors.len(), 0, "Unroll should have no errors"); + assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled"); + + assert_normalized_ssa_equals(ssa, expected); + } + + /// Test that we can unroll the loop in the ticket if we don't have too many iterations. + #[test] + fn test_brillig_unroll_6470_small() { + // Few enough iterations so that we can perform the unroll. + let ssa = brillig_unroll_test_case_6470(3); + let (ssa, errors) = try_unroll_loops(ssa); + assert_eq!(errors.len(), 0, "Unroll should have no errors"); + assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled"); + + // The IDs are shifted by one compared to what the ACIR version printed. + let expected = " + brillig(inline) fn main f0 { + b0(v0: [u64; 6]): + inc_rc v0 + v2 = make_array [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] : [u64; 6] + inc_rc v2 + v3 = allocate -> &mut [u64; 6] + store v2 at v3 + v4 = load v3 -> [u64; 6] + v6 = array_get v0, index u32 0 -> u64 + v8 = add v6, u64 1 + v9 = array_set v4, index u32 0, value v8 + store v9 at v3 + v10 = load v3 -> [u64; 6] + v12 = array_get v0, index u32 1 -> u64 + v13 = add v12, u64 1 + v14 = array_set v10, index u32 1, value v13 + store v14 at v3 + v15 = load v3 -> [u64; 6] + v17 = array_get v0, index u32 2 -> u64 + v18 = add v17, u64 1 + v19 = array_set v15, index u32 2, value v18 + store v19 at v3 + jmp b1() + b1(): + v20 = load v3 -> [u64; 6] + dec_rc v0 + return v20 + } + "; + assert_normalized_ssa_equals(ssa, expected); + } + + /// Test that with more iterations it's not unrolled. + #[test] + fn test_brillig_unroll_6470_large() { + // More iterations than it can unroll + let parse_ssa = || brillig_unroll_test_case_6470(6); + let ssa = parse_ssa(); + let stats = loop0_stats(&ssa); + assert!(!stats.is_small(), "the loop should be considered large"); + + let (ssa, errors) = try_unroll_loops(ssa); + assert_eq!(errors.len(), 0, "Unroll should have no errors"); + assert_normalized_ssa_equals(ssa, parse_ssa().to_string().as_str()); + } + + /// Test that `break` and `continue` stop unrolling without any panic. + #[test] + fn test_brillig_unroll_break_and_continue() { + // unconstrained fn main() { + // let mut count = 0; + // for i in 0..10 { + // if i == 2 { + // continue; + // } + // if i == 5 { + // break; + // } + // count += 1; + // } + // assert(count == 4); + // } + let src = " + brillig(inline) fn main f0 { + b0(): + v1 = allocate -> &mut Field + store Field 0 at v1 + jmp b1(u32 0) + b1(v0: u32): + v5 = lt v0, u32 10 + jmpif v5 then: b2, else: b6 + b2(): + v7 = eq v0, u32 2 + jmpif v7 then: b7, else: b3 + b7(): + v18 = add v0, u32 1 + jmp b1(v18) + b3(): + v9 = eq v0, u32 5 + jmpif v9 then: b5, else: b4 + b5(): + jmp b6() + b6(): + v15 = load v1 -> Field + v17 = eq v15, Field 4 + constrain v15 == Field 4 + return + b4(): + v10 = load v1 -> Field + v12 = add v10, Field 1 + store v12 at v1 + v14 = add v0, u32 1 + jmp b1(v14) + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let (ssa, errors) = try_unroll_loops(ssa); + assert_eq!(errors.len(), 0, "Unroll should have no errors"); + assert_normalized_ssa_equals(ssa, src); + } + + /// Simple test loop: + /// ```text + /// unconstrained fn main(sum: u32) { + /// assert(loop(0, 4) == sum); + /// } + /// + /// fn loop(from: u32, to: u32) -> u32 { + /// let mut sum = 0; + /// for i in from..to { + /// sum = sum + i; + /// } + /// sum + /// } + /// ``` + /// We can check what the ACIR unrolling behavior would be by + /// removing the `unconstrained` from the `main` function and + /// compiling the program with `nargo --test-program . compile --show-ssa`. + fn brillig_unroll_test_case() -> Ssa { + let src = " + // After `static_assert` and `assert_constant`: + brillig(inline) fn main f0 { + b0(v0: u32): + v2 = allocate -> &mut u32 + store u32 0 at v2 + jmp b1(u32 0) + b1(v1: u32): + v5 = lt v1, u32 4 + jmpif v5 then: b3, else: b2 + b3(): + v8 = load v2 -> u32 + v9 = add v8, v1 + store v9 at v2 + v11 = add v1, u32 1 + jmp b1(v11) + b2(): + v6 = load v2 -> u32 + v7 = eq v6, v0 + constrain v6 == v0 + return + } + "; + Ssa::from_str(src).unwrap() + } + + /// Test case from #6470: + /// ```text + /// unconstrained fn __validate_gt_remainder(a_u60: [u64; 6]) -> [u64; 6] { + /// let mut result_u60: [u64; 6] = [0; 6]; + /// + /// for i in 0..6 { + /// result_u60[i] = a_u60[i] + 1; + /// } + /// + /// result_u60 + /// } + /// ``` + /// The `num_iterations` parameter can be used to make it more costly to inline. + fn brillig_unroll_test_case_6470(num_iterations: usize) -> Ssa { + let src = format!( + " + // After `static_assert` and `assert_constant`: + brillig(inline) fn main f0 {{ + b0(v0: [u64; 6]): + inc_rc v0 + v3 = make_array [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] : [u64; 6] + inc_rc v3 + v4 = allocate -> &mut [u64; 6] + store v3 at v4 + jmp b1(u32 0) + b1(v1: u32): + v7 = lt v1, u32 {num_iterations} + jmpif v7 then: b3, else: b2 + b3(): + v9 = load v4 -> [u64; 6] + v10 = array_get v0, index v1 -> u64 + v12 = add v10, u64 1 + v13 = array_set v9, index v1, value v12 + v15 = add v1, u32 1 + store v13 at v4 + v16 = add v1, u32 1 // duplicate + jmp b1(v16) + b2(): + v8 = load v4 -> [u64; 6] + dec_rc v0 + return v8 + }} + " + ); + Ssa::from_str(&src).unwrap() + } + + // Boilerplate stats of the first loop in the SSA. + fn loop0_stats(ssa: &Ssa) -> BoilerplateStats { + let function = ssa.main(); + let mut loops = Loops::find_all(function); + let loop0 = loops.yet_to_unroll.pop().expect("there should be a loop"); + loop0.boilerplate_stats(function, &loops.cfg).expect("there should be stats") } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/ast.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/ast.rs index f8fe8c68a98..a34b7fd70d3 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/ast.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/ast.rs @@ -104,6 +104,11 @@ pub(crate) enum ParsedInstruction { value: ParsedValue, typ: Type, }, + MakeArray { + target: Identifier, + elements: Vec, + typ: Type, + }, Not { target: Identifier, value: ParsedValue, @@ -131,9 +136,8 @@ pub(crate) enum ParsedTerminator { Return(Vec), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum ParsedValue { NumericConstant { constant: FieldElement, typ: Type }, - Array { values: Vec, typ: Type }, Variable(Identifier), } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index 2a94a4fd1eb..552ac0781c7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -1,7 +1,5 @@ use std::collections::HashMap; -use im::Vector; - use crate::ssa::{ function_builder::FunctionBuilder, ir::{basic_block::BasicBlockId, function::FunctionId, value::ValueId}, @@ -27,7 +25,11 @@ struct Translator { /// Maps block names to their IDs blocks: HashMap>, - /// Maps variable names to their IDs + /// Maps variable names to their IDs. + /// + /// This is necessary because the SSA we parse might have undergone some + /// passes already which replaced some of the original IDs. The translator + /// will recreate the SSA step by step, which can result in a new ID layout. variables: HashMap>, } @@ -213,6 +215,14 @@ impl Translator { let value = self.translate_value(value)?; self.builder.increment_array_reference_count(value); } + ParsedInstruction::MakeArray { target, elements, typ } => { + let elements = elements + .into_iter() + .map(|element| self.translate_value(element)) + .collect::>()?; + let value_id = self.builder.insert_make_array(elements, typ); + self.define_variable(target, value_id)?; + } ParsedInstruction::Load { target, value, typ } => { let value = self.translate_value(value)?; let value_id = self.builder.insert_load(value, typ); @@ -255,13 +265,6 @@ impl Translator { ParsedValue::NumericConstant { constant, typ } => { Ok(self.builder.numeric_constant(constant, typ)) } - ParsedValue::Array { values, typ } => { - let mut translated_values = Vector::new(); - for value in values { - translated_values.push_back(self.translate_value(value)?); - } - Ok(self.builder.array_constant(translated_values, typ)) - } ParsedValue::Variable(identifier) => self.lookup_variable(identifier), } } @@ -308,7 +311,13 @@ impl Translator { } fn finish(self) -> Ssa { - self.builder.finish() + let mut ssa = self.builder.finish(); + // Normalize the IDs so we have a better chance of matching the SSA we parsed + // after the step-by-step reconstruction done during translation. This assumes + // that the SSA we parsed was printed by the `SsaBuilder`, which normalizes + // before each print. + ssa.normalize_ids(); + ssa } fn current_function_id(&self) -> FunctionId { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/mod.rs similarity index 97% rename from noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser.rs rename to noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/mod.rs index 11d43284786..2db2c636a8f 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/mod.rs @@ -429,6 +429,15 @@ impl<'a> Parser<'a> { return Ok(ParsedInstruction::Load { target, value, typ }); } + if self.eat_keyword(Keyword::MakeArray)? { + self.eat_or_error(Token::LeftBracket)?; + let elements = self.parse_comma_separated_values()?; + self.eat_or_error(Token::RightBracket)?; + self.eat_or_error(Token::Colon)?; + let typ = self.parse_type()?; + return Ok(ParsedInstruction::MakeArray { target, elements, typ }); + } + if self.eat_keyword(Keyword::Not)? { let value = self.parse_value_or_error()?; return Ok(ParsedInstruction::Not { target, value }); @@ -557,10 +566,6 @@ impl<'a> Parser<'a> { return Ok(Some(value)); } - if let Some(value) = self.parse_array_value()? { - return Ok(Some(value)); - } - if let Some(identifier) = self.eat_identifier()? { return Ok(Some(ParsedValue::Variable(identifier))); } @@ -590,23 +595,6 @@ impl<'a> Parser<'a> { } } - fn parse_array_value(&mut self) -> ParseResult> { - if self.eat(Token::LeftBracket)? { - let values = self.parse_comma_separated_values()?; - self.eat_or_error(Token::RightBracket)?; - self.eat_or_error(Token::Keyword(Keyword::Of))?; - let types = self.parse_types()?; - let types_len = types.len(); - let values_len = values.len(); - Ok(Some(ParsedValue::Array { - typ: Type::Array(Arc::new(types), values_len / types_len), - values, - })) - } else { - Ok(None) - } - } - fn parse_types(&mut self) -> ParseResult> { if self.eat(Token::LeftParen)? { let types = self.parse_comma_separated_types()?; diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/tests.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/tests.rs index 9205353151e..60d398bf9d5 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/tests.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/tests.rs @@ -54,33 +54,36 @@ fn test_return_integer() { } #[test] -fn test_return_array() { +fn test_make_array() { let src = " acir(inline) fn main f0 { b0(): - return [Field 1] of Field + v1 = make_array [Field 1] : [Field; 1] + return v1 } "; assert_ssa_roundtrip(src); } #[test] -fn test_return_empty_array() { +fn test_make_empty_array() { let src = " acir(inline) fn main f0 { b0(): - return [] of Field + v0 = make_array [] : [Field; 0] + return v0 } "; assert_ssa_roundtrip(src); } #[test] -fn test_return_composite_array() { +fn test_make_composite_array() { let src = " acir(inline) fn main f0 { b0(): - return [Field 1, Field 2] of (Field, Field) + v2 = make_array [Field 1, Field 2] : [(Field, Field); 1] + return v2 } "; assert_ssa_roundtrip(src); @@ -103,8 +106,8 @@ fn test_multiple_blocks_and_jmp() { acir(inline) fn main f0 { b0(): jmp b1(Field 1) - b1(v1: Field): - return v1 + b1(v0: Field): + return v0 } "; assert_ssa_roundtrip(src); @@ -115,11 +118,11 @@ fn test_jmpif() { let src = " acir(inline) fn main f0 { b0(v0: Field): - jmpif v0 then: b1, else: b2 - b1(): - return + jmpif v0 then: b2, else: b1 b2(): return + b1(): + return } "; assert_ssa_roundtrip(src); @@ -151,7 +154,9 @@ fn test_call_multiple_return_values() { } acir(inline) fn foo f1 { b0(): - return [Field 1, Field 2, Field 3] of Field, [Field 4] of Field + v3 = make_array [Field 1, Field 2, Field 3] : [Field; 3] + v5 = make_array [Field 4] : [Field; 1] + return v3, v5 } "; assert_ssa_roundtrip(src); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/token.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/token.rs index d648f58de41..f663879e899 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/token.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/token.rs @@ -136,6 +136,7 @@ pub(crate) enum Keyword { Jmpif, Load, Lt, + MakeArray, MaxBitSize, Mod, Mul, @@ -190,6 +191,7 @@ impl Keyword { "jmpif" => Keyword::Jmpif, "load" => Keyword::Load, "lt" => Keyword::Lt, + "make_array" => Keyword::MakeArray, "max_bit_size" => Keyword::MaxBitSize, "mod" => Keyword::Mod, "mul" => Keyword::Mul, @@ -248,6 +250,7 @@ impl Display for Keyword { Keyword::Jmpif => write!(f, "jmpif"), Keyword::Load => write!(f, "load"), Keyword::Lt => write!(f, "lt"), + Keyword::MakeArray => write!(f, "make_array"), Keyword::MaxBitSize => write!(f, "max_bit_size"), Keyword::Mod => write!(f, "mod"), Keyword::Mul => write!(f, "mul"), diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index 96e779482a4..c50f0a7f45c 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -291,7 +291,7 @@ impl<'a> FunctionContext<'a> { }); } - self.builder.array_constant(array, typ).into() + self.builder.insert_make_array(array, typ).into() } fn codegen_block(&mut self, block: &[Expression]) -> Result { @@ -466,6 +466,7 @@ impl<'a> FunctionContext<'a> { /// /// For example, the loop `for i in start .. end { body }` is codegen'd as: /// + /// ```text /// v0 = ... codegen start ... /// v1 = ... codegen end ... /// br loop_entry(v0) @@ -478,6 +479,7 @@ impl<'a> FunctionContext<'a> { /// br loop_entry(v4) /// loop_end(): /// ... This is the current insert point after codegen_for finishes ... + /// ``` fn codegen_for(&mut self, for_expr: &ast::For) -> Result { let loop_entry = self.builder.insert_block(); let loop_body = self.builder.insert_block(); @@ -529,6 +531,7 @@ impl<'a> FunctionContext<'a> { /// /// For example, the expression `if cond { a } else { b }` is codegen'd as: /// + /// ```text /// v0 = ... codegen cond ... /// brif v0, then: then_block, else: else_block /// then_block(): @@ -539,16 +542,19 @@ impl<'a> FunctionContext<'a> { /// br end_if(v2) /// end_if(v3: ?): // Type of v3 matches the type of a and b /// ... This is the current insert point after codegen_if finishes ... + /// ``` /// /// As another example, the expression `if cond { a }` is codegen'd as: /// + /// ```text /// v0 = ... codegen cond ... - /// brif v0, then: then_block, else: end_block + /// brif v0, then: then_block, else: end_if /// then_block: /// v1 = ... codegen a ... /// br end_if() /// end_if: // No block parameter is needed. Without an else, the unit value is always returned. /// ... This is the current insert point after codegen_if finishes ... + /// ``` fn codegen_if(&mut self, if_expr: &ast::If) -> Result { let condition = self.codegen_non_tuple_expression(&if_expr.condition)?; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/traits.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/traits.rs index 723df775b1e..475e3ff1be9 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/traits.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/traits.rs @@ -24,6 +24,7 @@ pub struct NoirTrait { pub items: Vec>, pub attributes: Vec, pub visibility: ItemVisibility, + pub is_alias: bool, } /// Any declaration inside the body of a trait that a user is required to @@ -77,6 +78,9 @@ pub struct NoirTraitImpl { pub where_clause: Vec, pub items: Vec>, + + /// true if generated at compile-time, e.g. from a trait alias + pub is_synthetic: bool, } /// Represents a simple trait constraint such as `where Foo: TraitY` @@ -130,12 +134,19 @@ impl Display for TypeImpl { } } +// TODO: display where clauses (follow-up issue) impl Display for NoirTrait { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let generics = vecmap(&self.generics, |generic| generic.to_string()); let generics = if generics.is_empty() { "".into() } else { generics.join(", ") }; write!(f, "trait {}{}", self.name, generics)?; + + if self.is_alias { + let bounds = vecmap(&self.bounds, |bound| bound.to_string()).join(" + "); + return write!(f, " = {};", bounds); + } + if !self.bounds.is_empty() { let bounds = vecmap(&self.bounds, |bound| bound.to_string()).join(" + "); write!(f, ": {}", bounds)?; @@ -222,6 +233,11 @@ impl Display for TraitBound { impl Display for NoirTraitImpl { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Synthetic NoirTraitImpl's don't get printed + if self.is_synthetic { + return Ok(()); + } + write!(f, "impl")?; if !self.impl_generics.is_empty() { write!( diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/errors.rs index a63601a4280..a6b6120986e 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -208,7 +208,7 @@ pub enum TypeCheckError { #[derive(Debug, Clone, PartialEq, Eq)] pub struct NoMatchingImplFoundError { - constraints: Vec<(Type, String)>, + pub(crate) constraints: Vec<(Type, String)>, pub span: Span, } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs index 63471efac43..bcb4ce1c616 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs @@ -95,6 +95,8 @@ pub enum ParserErrorReason { AssociatedTypesNotAllowedInPaths, #[error("Associated types are not allowed on a method call")] AssociatedTypesNotAllowedInMethodCalls, + #[error("Empty trait alias")] + EmptyTraitAlias, #[error( "Wrong number of arguments for attribute `{}`. Expected {}, found {}", name, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/impls.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/impls.rs index 9215aec2742..8e6b3bae0e9 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/impls.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/impls.rs @@ -113,6 +113,7 @@ impl<'a> Parser<'a> { let object_type = self.parse_type_or_error(); let where_clause = self.parse_where_clause(); let items = self.parse_trait_impl_body(); + let is_synthetic = false; NoirTraitImpl { impl_generics, @@ -121,6 +122,7 @@ impl<'a> Parser<'a> { object_type, where_clause, items, + is_synthetic, } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/item.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/item.rs index 4fbcd7abac5..ce712b559d8 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/item.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/item.rs @@ -1,3 +1,5 @@ +use iter_extended::vecmap; + use crate::{ parser::{labels::ParsingRuleLabel, Item, ItemKind}, token::{Keyword, Token}, @@ -13,31 +15,32 @@ impl<'a> Parser<'a> { } pub(crate) fn parse_module_items(&mut self, nested: bool) -> Vec { - self.parse_many("items", without_separator(), |parser| { + self.parse_many_to_many("items", without_separator(), |parser| { parser.parse_module_item_in_list(nested) }) } - fn parse_module_item_in_list(&mut self, nested: bool) -> Option { + fn parse_module_item_in_list(&mut self, nested: bool) -> Vec { loop { // We only break out of the loop on `}` if we are inside a `mod { ..` if nested && self.at(Token::RightBrace) { - return None; + return vec![]; } // We always break on EOF (we don't error because if we are inside `mod { ..` // the outer parsing logic will error instead) if self.at_eof() { - return None; + return vec![]; } - let Some(item) = self.parse_item() else { + let parsed_items = self.parse_item(); + if parsed_items.is_empty() { // If we couldn't parse an item we check which token we got match self.token.token() { Token::RightBrace if nested => { - return None; + return vec![]; } - Token::EOF => return None, + Token::EOF => return vec![], _ => (), } @@ -47,7 +50,7 @@ impl<'a> Parser<'a> { continue; }; - return Some(item); + return parsed_items; } } @@ -85,15 +88,19 @@ impl<'a> Parser<'a> { } /// Item = OuterDocComments ItemKind - fn parse_item(&mut self) -> Option { + fn parse_item(&mut self) -> Vec { let start_span = self.current_token_span; let doc_comments = self.parse_outer_doc_comments(); - let kind = self.parse_item_kind()?; + let kinds = self.parse_item_kind(); let span = self.span_since(start_span); - Some(Item { kind, span, doc_comments }) + vecmap(kinds, |kind| Item { kind, span, doc_comments: doc_comments.clone() }) } + /// This method returns one 'ItemKind' in the majority of cases. + /// The current exception is when parsing a trait alias, + /// which returns both the trait and the impl. + /// /// ItemKind /// = InnerAttribute /// | Attributes Modifiers @@ -106,9 +113,9 @@ impl<'a> Parser<'a> { /// | TypeAlias /// | Function /// ) - fn parse_item_kind(&mut self) -> Option { + fn parse_item_kind(&mut self) -> Vec { if let Some(kind) = self.parse_inner_attribute() { - return Some(ItemKind::InnerAttribute(kind)); + return vec![ItemKind::InnerAttribute(kind)]; } let start_span = self.current_token_span; @@ -122,78 +129,81 @@ impl<'a> Parser<'a> { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); let use_tree = self.parse_use_tree(); - return Some(ItemKind::Import(use_tree, modifiers.visibility)); + return vec![ItemKind::Import(use_tree, modifiers.visibility)]; } if let Some(is_contract) = self.eat_mod_or_contract() { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(self.parse_mod_or_contract(attributes, is_contract, modifiers.visibility)); + return vec![self.parse_mod_or_contract(attributes, is_contract, modifiers.visibility)]; } if self.eat_keyword(Keyword::Struct) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::Struct(self.parse_struct( + return vec![ItemKind::Struct(self.parse_struct( attributes, modifiers.visibility, start_span, - ))); + ))]; } if self.eat_keyword(Keyword::Impl) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(match self.parse_impl() { + return vec![match self.parse_impl() { Impl::Impl(type_impl) => ItemKind::Impl(type_impl), Impl::TraitImpl(noir_trait_impl) => ItemKind::TraitImpl(noir_trait_impl), - }); + }]; } if self.eat_keyword(Keyword::Trait) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::Trait(self.parse_trait( - attributes, - modifiers.visibility, - start_span, - ))); + let (noir_trait, noir_impl) = + self.parse_trait(attributes, modifiers.visibility, start_span); + let mut output = vec![ItemKind::Trait(noir_trait)]; + if let Some(noir_impl) = noir_impl { + output.push(ItemKind::TraitImpl(noir_impl)); + } + + return output; } if self.eat_keyword(Keyword::Global) { self.unconstrained_not_applicable(modifiers); - return Some(ItemKind::Global( + return vec![ItemKind::Global( self.parse_global( attributes, modifiers.comptime.is_some(), modifiers.mutable.is_some(), ), modifiers.visibility, - )); + )]; } if self.eat_keyword(Keyword::Type) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::TypeAlias( + return vec![ItemKind::TypeAlias( self.parse_type_alias(modifiers.visibility, start_span), - )); + )]; } if self.eat_keyword(Keyword::Fn) { self.mutable_not_applicable(modifiers); - return Some(ItemKind::Function(self.parse_function( + return vec![ItemKind::Function(self.parse_function( attributes, modifiers.visibility, modifiers.comptime.is_some(), modifiers.unconstrained.is_some(), false, // allow_self - ))); + ))]; } - None + vec![] } fn eat_mod_or_contract(&mut self) -> Option { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/parse_many.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/parse_many.rs index ea4dfe97122..be156eb1618 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/parse_many.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/parse_many.rs @@ -18,6 +18,19 @@ impl<'a> Parser<'a> { self.parse_many_return_trailing_separator_if_any(items, separated_by, f).0 } + /// parse_many, where the given function `f` may return multiple results + pub(super) fn parse_many_to_many( + &mut self, + items: &'static str, + separated_by: SeparatedBy, + f: F, + ) -> Vec + where + F: FnMut(&mut Parser<'a>) -> Vec, + { + self.parse_many_to_many_return_trailing_separator_if_any(items, separated_by, f).0 + } + /// Same as parse_many, but returns a bool indicating whether a trailing separator was found. pub(super) fn parse_many_return_trailing_separator_if_any( &mut self, @@ -27,6 +40,26 @@ impl<'a> Parser<'a> { ) -> (Vec, bool) where F: FnMut(&mut Parser<'a>) -> Option, + { + let f = |x: &mut Parser<'a>| { + if let Some(result) = f(x) { + vec![result] + } else { + vec![] + } + }; + self.parse_many_to_many_return_trailing_separator_if_any(items, separated_by, f) + } + + /// Same as parse_many, but returns a bool indicating whether a trailing separator was found. + fn parse_many_to_many_return_trailing_separator_if_any( + &mut self, + items: &'static str, + separated_by: SeparatedBy, + mut f: F, + ) -> (Vec, bool) + where + F: FnMut(&mut Parser<'a>) -> Vec, { let mut elements: Vec = Vec::new(); let mut trailing_separator = false; @@ -38,12 +71,13 @@ impl<'a> Parser<'a> { } let start_span = self.current_token_span; - let Some(element) = f(self) else { + let mut new_elements = f(self); + if new_elements.is_empty() { if let Some(end) = &separated_by.until { self.eat(end.clone()); } break; - }; + } if let Some(separator) = &separated_by.token { if !trailing_separator && !elements.is_empty() { @@ -51,7 +85,7 @@ impl<'a> Parser<'a> { } } - elements.push(element); + elements.append(&mut new_elements); trailing_separator = if let Some(separator) = &separated_by.token { self.eat(separator.clone()) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/traits.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/traits.rs index fead6a34c82..e03b629e9ea 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/traits.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/traits.rs @@ -1,9 +1,14 @@ +use iter_extended::vecmap; + use noirc_errors::Span; -use crate::ast::{Documented, ItemVisibility, NoirTrait, Pattern, TraitItem, UnresolvedType}; +use crate::ast::{ + Documented, GenericTypeArg, GenericTypeArgs, ItemVisibility, NoirTrait, Path, Pattern, + TraitItem, UnresolvedGeneric, UnresolvedTraitConstraint, UnresolvedType, +}; use crate::{ ast::{Ident, UnresolvedTypeData}, - parser::{labels::ParsingRuleLabel, ParserErrorReason}, + parser::{labels::ParsingRuleLabel, NoirTraitImpl, ParserErrorReason}, token::{Attribute, Keyword, SecondaryAttribute, Token}, }; @@ -12,34 +17,117 @@ use super::Parser; impl<'a> Parser<'a> { /// Trait = 'trait' identifier Generics ( ':' TraitBounds )? WhereClause TraitBody + /// | 'trait' identifier Generics '=' TraitBounds WhereClause ';' pub(crate) fn parse_trait( &mut self, attributes: Vec<(Attribute, Span)>, visibility: ItemVisibility, start_span: Span, - ) -> NoirTrait { + ) -> (NoirTrait, Option) { let attributes = self.validate_secondary_attributes(attributes); let Some(name) = self.eat_ident() else { self.expected_identifier(); - return empty_trait(attributes, visibility, self.span_since(start_span)); + let noir_trait = empty_trait(attributes, visibility, self.span_since(start_span)); + let no_implicit_impl = None; + return (noir_trait, no_implicit_impl); }; let generics = self.parse_generics(); - let bounds = if self.eat_colon() { self.parse_trait_bounds() } else { Vec::new() }; - let where_clause = self.parse_where_clause(); - let items = self.parse_trait_body(); - NoirTrait { + // Trait aliases: + // trait Foo<..> = A + B + E where ..; + let (bounds, where_clause, items, is_alias) = if self.eat_assign() { + let bounds = self.parse_trait_bounds(); + + if bounds.is_empty() { + self.push_error(ParserErrorReason::EmptyTraitAlias, self.previous_token_span); + } + + let where_clause = self.parse_where_clause(); + let items = Vec::new(); + if !self.eat_semicolon() { + self.expected_token(Token::Semicolon); + } + + let is_alias = true; + (bounds, where_clause, items, is_alias) + } else { + let bounds = if self.eat_colon() { self.parse_trait_bounds() } else { Vec::new() }; + let where_clause = self.parse_where_clause(); + let items = self.parse_trait_body(); + let is_alias = false; + (bounds, where_clause, items, is_alias) + }; + + let span = self.span_since(start_span); + + let noir_impl = is_alias.then(|| { + let object_type_ident = Ident::new("#T".to_string(), span); + let object_type_path = Path::from_ident(object_type_ident.clone()); + let object_type_generic = UnresolvedGeneric::Variable(object_type_ident); + + let is_synthesized = true; + let object_type = UnresolvedType { + typ: UnresolvedTypeData::Named(object_type_path, vec![].into(), is_synthesized), + span, + }; + + let mut impl_generics = generics.clone(); + impl_generics.push(object_type_generic); + + let trait_name = Path::from_ident(name.clone()); + let trait_generics: GenericTypeArgs = vecmap(generics.clone(), |generic| { + let is_synthesized = true; + let generic_type = UnresolvedType { + typ: UnresolvedTypeData::Named( + Path::from_ident(generic.ident().clone()), + vec![].into(), + is_synthesized, + ), + span, + }; + + GenericTypeArg::Ordered(generic_type) + }) + .into(); + + // bounds from trait + let mut where_clause = where_clause.clone(); + for bound in bounds.clone() { + where_clause.push(UnresolvedTraitConstraint { + typ: object_type.clone(), + trait_bound: bound, + }); + } + + let items = vec![]; + let is_synthetic = true; + + NoirTraitImpl { + impl_generics, + trait_name, + trait_generics, + object_type, + where_clause, + items, + is_synthetic, + } + }); + + let noir_trait = NoirTrait { name, generics, bounds, where_clause, - span: self.span_since(start_span), + span, items, attributes, visibility, - } + is_alias, + }; + + (noir_trait, noir_impl) } /// TraitBody = '{' ( OuterDocComments TraitItem )* '}' @@ -188,28 +276,51 @@ fn empty_trait( items: Vec::new(), attributes, visibility, + is_alias: false, } } #[cfg(test)] mod tests { use crate::{ - ast::{NoirTrait, TraitItem}, + ast::{NoirTrait, NoirTraitImpl, TraitItem}, parser::{ - parser::{parse_program, tests::expect_no_errors}, + parser::{parse_program, tests::expect_no_errors, ParserErrorReason}, ItemKind, }, }; - fn parse_trait_no_errors(src: &str) -> NoirTrait { + fn parse_trait_opt_impl_no_errors(src: &str) -> (NoirTrait, Option) { let (mut module, errors) = parse_program(src); expect_no_errors(&errors); - assert_eq!(module.items.len(), 1); - let item = module.items.remove(0); + let (item, impl_item) = if module.items.len() == 2 { + let item = module.items.remove(0); + let impl_item = module.items.remove(0); + (item, Some(impl_item)) + } else { + assert_eq!(module.items.len(), 1); + let item = module.items.remove(0); + (item, None) + }; let ItemKind::Trait(noir_trait) = item.kind else { panic!("Expected trait"); }; - noir_trait + let noir_trait_impl = impl_item.map(|impl_item| { + let ItemKind::TraitImpl(noir_trait_impl) = impl_item.kind else { + panic!("Expected impl"); + }; + noir_trait_impl + }); + (noir_trait, noir_trait_impl) + } + + fn parse_trait_with_impl_no_errors(src: &str) -> (NoirTrait, NoirTraitImpl) { + let (noir_trait, noir_trait_impl) = parse_trait_opt_impl_no_errors(src); + (noir_trait, noir_trait_impl.expect("expected a NoirTraitImpl")) + } + + fn parse_trait_no_errors(src: &str) -> NoirTrait { + parse_trait_opt_impl_no_errors(src).0 } #[test] @@ -220,6 +331,15 @@ mod tests { assert!(noir_trait.generics.is_empty()); assert!(noir_trait.where_clause.is_empty()); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias() { + let src = "trait Foo = ;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -230,6 +350,50 @@ mod tests { assert_eq!(noir_trait.generics.len(), 2); assert!(noir_trait.where_clause.is_empty()); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias_with_generics() { + let src = "trait Foo = Bar + Baz;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.name.to_string(), "Foo"); + assert_eq!(noir_trait_alias.generics.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert!(noir_trait_alias.where_clause.is_empty()); + assert!(noir_trait_alias.items.is_empty()); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 3); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 2); + assert_eq!(noir_trait_impl.where_clause.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.is_empty(), noir_trait_alias.where_clause.is_empty()); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias_with_generics() { + let src = "trait Foo = ;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -240,6 +404,54 @@ mod tests { assert_eq!(noir_trait.generics.len(), 2); assert_eq!(noir_trait.where_clause.len(), 1); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias_with_where_clause() { + let src = "trait Foo = Bar + Baz where A: Z;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.name.to_string(), "Foo"); + assert_eq!(noir_trait_alias.generics.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert_eq!(noir_trait_alias.where_clause.len(), 1); + assert!(noir_trait_alias.items.is_empty()); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 3); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 2); + assert_eq!(noir_trait_impl.where_clause.len(), 3); + assert_eq!(noir_trait_impl.where_clause[0].to_string(), "A: Z"); + assert_eq!(noir_trait_impl.where_clause[1].to_string(), "#T: Bar"); + assert_eq!(noir_trait_impl.where_clause[2].to_string(), "#T: Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz where A: Z {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.len(), noir_trait_alias.where_clause.len()); + assert_eq!( + noir_trait.where_clause[0].to_string(), + noir_trait_alias.where_clause[0].to_string() + ); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias_with_where_clause() { + let src = "trait Foo = where A: Z;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -253,6 +465,7 @@ mod tests { panic!("Expected type"); }; assert_eq!(name.to_string(), "Elem"); + assert!(!noir_trait.is_alias); } #[test] @@ -268,6 +481,7 @@ mod tests { assert_eq!(name.to_string(), "x"); assert_eq!(typ.to_string(), "Field"); assert_eq!(default_value.unwrap().to_string(), "1"); + assert!(!noir_trait.is_alias); } #[test] @@ -281,6 +495,7 @@ mod tests { panic!("Expected function"); }; assert!(body.is_none()); + assert!(!noir_trait.is_alias); } #[test] @@ -294,6 +509,7 @@ mod tests { panic!("Expected function"); }; assert!(body.is_some()); + assert!(!noir_trait.is_alias); } #[test] @@ -306,5 +522,39 @@ mod tests { assert_eq!(noir_trait.bounds[1].to_string(), "Baz"); assert_eq!(noir_trait.to_string(), "trait Foo: Bar + Baz {\n}"); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias() { + let src = "trait Foo = Bar + Baz;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.bounds.len(), 2); + + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + + assert_eq!(noir_trait_alias.to_string(), "trait Foo = Bar + Baz;"); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 1); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 0); + assert_eq!(noir_trait_impl.where_clause.len(), 2); + assert_eq!(noir_trait_impl.where_clause[0].to_string(), "#T: Bar"); + assert_eq!(noir_trait_impl.where_clause[1].to_string(), "#T: Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.is_empty(), noir_trait_alias.where_clause.is_empty()); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests/traits.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests/traits.rs index 0adf5c90bea..811a32bab86 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests/traits.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests/traits.rs @@ -1,5 +1,6 @@ use crate::hir::def_collector::dc_crate::CompilationError; use crate::hir::resolution::errors::ResolverError; +use crate::hir::type_check::TypeCheckError; use crate::tests::{get_program_errors, get_program_with_maybe_parser_errors}; use super::assert_no_errors; @@ -320,6 +321,251 @@ fn regression_6314_double_inheritance() { assert_no_errors(src); } +#[test] +fn trait_alias_single_member() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Baz = Foo; + + impl Foo for Field { + fn foo(self) -> Self { self } + } + + fn baz(x: T) -> T where T: Baz { + x.foo() + } + + fn main() { + let x: Field = 0; + let _ = baz(x); + } + "#; + assert_no_errors(src); +} + +#[test] +fn trait_alias_two_members() { + let src = r#" + pub trait Foo { + fn foo(self) -> Self; + } + + pub trait Bar { + fn bar(self) -> Self; + } + + pub trait Baz = Foo + Bar; + + fn baz(x: T) -> T where T: Baz { + x.foo().bar() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> Self { + self + 2 + } + } + + fn main() { + assert(0.foo().bar() == baz(0)); + }"#; + + assert_no_errors(src); +} + +#[test] +fn trait_alias_polymorphic_inheritance() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Bar { + fn bar(self) -> T; + } + + trait Baz = Foo + Bar; + + fn baz(x: T) -> U where T: Baz { + x.foo().bar() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> bool { + true + } + } + + fn main() { + assert(0.foo().bar() == baz(0)); + }"#; + + assert_no_errors(src); +} + +// TODO(https://github.com/noir-lang/noir/issues/6467): currently fails with the +// same errors as the desugared version +#[test] +fn trait_alias_polymorphic_where_clause() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Bar { + fn bar(self) -> T; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux = Foo + Bar where T: Baz; + + fn qux(x: T) -> bool where T: Qux { + x.foo().bar().baz() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> bool { + true + } + } + + impl Baz for bool { + fn baz(self) -> bool { + self + } + } + + fn main() { + assert(0.foo().bar().baz() == qux(0)); + } + "#; + + // TODO(https://github.com/noir-lang/noir/issues/6467) + // assert_no_errors(src); + let errors = get_program_errors(src); + assert_eq!(errors.len(), 2); + + match &errors[0].0 { + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name, .. + }) => { + assert_eq!(method_name, "baz"); + } + other => { + panic!("expected UnresolvedMethodCall, but found {:?}", other); + } + } + + match &errors[1].0 { + CompilationError::TypeError(TypeCheckError::NoMatchingImplFound(err)) => { + assert_eq!(err.constraints.len(), 2); + assert_eq!(err.constraints[0].1, "Baz"); + assert_eq!(err.constraints[1].1, "Qux<_>"); + } + other => { + panic!("expected NoMatchingImplFound, but found {:?}", other); + } + } +} + +// TODO(https://github.com/noir-lang/noir/issues/6467): currently failing, so +// this just tests that the trait alias has an equivalent error to the expected +// desugared version +#[test] +fn trait_alias_with_where_clause_has_equivalent_errors() { + let src = r#" + trait Bar { + fn bar(self) -> Self; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux: Bar where T: Baz {} + + impl Qux for U where + U: Bar, + T: Baz, + {} + + pub fn qux(x: T, _: U) -> bool where U: Qux { + x.baz() + } + + fn main() {} + "#; + + let alias_src = r#" + trait Bar { + fn bar(self) -> Self; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux = Bar where T: Baz; + + pub fn qux(x: T, _: U) -> bool where U: Qux { + x.baz() + } + + fn main() {} + "#; + + let errors = get_program_errors(src); + let alias_errors = get_program_errors(alias_src); + + assert_eq!(errors.len(), 1); + assert_eq!(alias_errors.len(), 1); + + match (&errors[0].0, &alias_errors[0].0) { + ( + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name, + object_type, + .. + }), + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name: alias_method_name, + object_type: alias_object_type, + .. + }), + ) => { + assert_eq!(method_name, alias_method_name); + assert_eq!(object_type, alias_object_type); + } + other => { + panic!("expected UnresolvedMethodCall, but found {:?}", other); + } + } +} + #[test] fn removes_assumed_parent_traits_after_function_ends() { let src = r#" diff --git a/noir/noir-repo/docs/docs/noir/concepts/traits.md b/noir/noir-repo/docs/docs/noir/concepts/traits.md index 9da00a77587..b6c0a886eb0 100644 --- a/noir/noir-repo/docs/docs/noir/concepts/traits.md +++ b/noir/noir-repo/docs/docs/noir/concepts/traits.md @@ -490,12 +490,95 @@ trait CompSciStudent: Programmer + Student { } ``` +### Trait Aliases + +Similar to the proposed Rust feature for [trait aliases](https://github.com/rust-lang/rust/blob/4d215e2426d52ca8d1af166d5f6b5e172afbff67/src/doc/unstable-book/src/language-features/trait-alias.md), +Noir supports aliasing one or more traits and using those aliases wherever +traits would normally be used. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> Self; +} + +// Equivalent to: +// trait Baz: Foo + Bar {} +// +// impl Baz for T where T: Foo + Bar {} +trait Baz = Foo + Bar; + +// We can use `Baz` to refer to `Foo + Bar` +fn baz(x: T) -> T where T: Baz { + x.foo().bar() +} +``` + +#### Generic Trait Aliases + +Trait aliases can also be generic by placing the generic arguments after the +trait name. These generics are in scope of every item within the trait alias. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> T; +} + +// Equivalent to: +// trait Baz: Foo + Bar {} +// +// impl Baz for U where U: Foo + Bar {} +trait Baz = Foo + Bar; +``` + +#### Trait Alias Where Clauses + +Trait aliases support where clauses to add trait constraints to any of their +generic arguments, e.g. ensuring `T: Baz` for a trait alias `Qux`. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> T; +} + +trait Baz { + fn baz(self) -> bool; +} + +// Equivalent to: +// trait Qux: Foo + Bar where T: Baz {} +// +// impl Qux for U where +// U: Foo + Bar, +// T: Baz, +// {} +trait Qux = Foo + Bar where T: Baz; +``` + +Note that while trait aliases support where clauses, +the equivalent traits can fail due to [#6467](https://github.com/noir-lang/noir/issues/6467) + ### Visibility -By default, like functions, traits are private to the module they exist in. You can use `pub` -to make the trait public or `pub(crate)` to make it public to just its crate: +By default, like functions, traits and trait aliases are private to the module +they exist in. You can use `pub` to make the trait public or `pub(crate)` to make +it public to just its crate: ```rust // This trait is now public pub trait Trait {} -``` \ No newline at end of file + +// This trait alias is now public +pub trait Baz = Foo + Bar; +``` diff --git a/noir/noir-repo/test_programs/compile_success_empty/embedded_curve_msm_simplification/Nargo.toml b/noir/noir-repo/test_programs/compile_success_empty/embedded_curve_msm_simplification/Nargo.toml new file mode 100644 index 00000000000..9c9bd8de04a --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/embedded_curve_msm_simplification/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "embedded_curve_msm_simplification" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir/noir-repo/test_programs/compile_success_empty/embedded_curve_msm_simplification/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/embedded_curve_msm_simplification/src/main.nr new file mode 100644 index 00000000000..e5aaa0f4d15 --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/embedded_curve_msm_simplification/src/main.nr @@ -0,0 +1,12 @@ +fn main() { + let pub_x = 0x0000000000000000000000000000000000000000000000000000000000000001; + let pub_y = 0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c; + + let g1_y = 17631683881184975370165255887551781615748388533673675138860; + let g1 = std::embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: g1_y, is_infinite: false }; + let scalar = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 1, hi: 0 }; + // Test that multi_scalar_mul correctly derives the public key + let res = std::embedded_curve_ops::multi_scalar_mul([g1], [scalar]); + assert(res.x == pub_x); + assert(res.y == pub_y); +} diff --git a/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/Nargo.toml b/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/Nargo.toml new file mode 100644 index 00000000000..f7076311e1d --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "sha256_brillig_performance_regression" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/Prover.toml b/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/Prover.toml new file mode 100644 index 00000000000..5bb7f354257 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/Prover.toml @@ -0,0 +1,16 @@ +input_amount = "1" +minimum_output_amount = "2" +secret_hash_for_L1_to_l2_message = "3" +uniswap_fee_tier = "4" + +[aztec_recipient] +inner = "5" + +[caller_on_L1] +inner = "6" + +[input_asset_bridge_portal_address] +inner = "7" + +[output_asset_bridge_portal_address] +inner = "8" diff --git a/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/src/main.nr b/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/src/main.nr new file mode 100644 index 00000000000..42cc6d4ff3b --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/sha256_brillig_performance_regression/src/main.nr @@ -0,0 +1,104 @@ +// Performance regression extracted from an aztec protocol contract. + +unconstrained fn main( + input_asset_bridge_portal_address: EthAddress, + input_amount: Field, + uniswap_fee_tier: Field, + output_asset_bridge_portal_address: EthAddress, + minimum_output_amount: Field, + aztec_recipient: AztecAddress, + secret_hash_for_L1_to_l2_message: Field, + caller_on_L1: EthAddress, +) -> pub Field { + let mut hash_bytes = [0; 260]; // 8 fields of 32 bytes each + 4 bytes fn selector + let input_token_portal_bytes: [u8; 32] = + input_asset_bridge_portal_address.to_field().to_be_bytes(); + let in_amount_bytes: [u8; 32] = input_amount.to_be_bytes(); + let uniswap_fee_tier_bytes: [u8; 32] = uniswap_fee_tier.to_be_bytes(); + let output_token_portal_bytes: [u8; 32] = + output_asset_bridge_portal_address.to_field().to_be_bytes(); + let amount_out_min_bytes: [u8; 32] = minimum_output_amount.to_be_bytes(); + let aztec_recipient_bytes: [u8; 32] = aztec_recipient.to_field().to_be_bytes(); + let secret_hash_for_L1_to_l2_message_bytes: [u8; 32] = + secret_hash_for_L1_to_l2_message.to_be_bytes(); + let caller_on_L1_bytes: [u8; 32] = caller_on_L1.to_field().to_be_bytes(); + + // The purpose of including the following selector is to make the message unique to that specific call. Note that + // it has nothing to do with calling the function. + let selector = comptime { + std::hash::keccak256( + "swap_public(address,uint256,uint24,address,uint256,bytes32,bytes32,address)".as_bytes(), + 75, + ) + }; + + hash_bytes[0] = selector[0]; + hash_bytes[1] = selector[1]; + hash_bytes[2] = selector[2]; + hash_bytes[3] = selector[3]; + + for i in 0..32 { + hash_bytes[i + 4] = input_token_portal_bytes[i]; + hash_bytes[i + 36] = in_amount_bytes[i]; + hash_bytes[i + 68] = uniswap_fee_tier_bytes[i]; + hash_bytes[i + 100] = output_token_portal_bytes[i]; + hash_bytes[i + 132] = amount_out_min_bytes[i]; + hash_bytes[i + 164] = aztec_recipient_bytes[i]; + hash_bytes[i + 196] = secret_hash_for_L1_to_l2_message_bytes[i]; + hash_bytes[i + 228] = caller_on_L1_bytes[i]; + } + + let content_hash = sha256_to_field(hash_bytes); + content_hash +} + +// Convert a 32 byte array to a field element by truncating the final byte +pub fn field_from_bytes_32_trunc(bytes32: [u8; 32]) -> Field { + // Convert it to a field element + let mut v = 1; + let mut high = 0 as Field; + let mut low = 0 as Field; + + for i in 0..15 { + // covers bytes 16..30 (31 is truncated and ignored) + low = low + (bytes32[15 + 15 - i] as Field) * v; + v = v * 256; + // covers bytes 0..14 + high = high + (bytes32[14 - i] as Field) * v; + } + // covers byte 15 + low = low + (bytes32[15] as Field) * v; + + low + high * v +} + +pub fn sha256_to_field(bytes_to_hash: [u8; N]) -> Field { + let sha256_hashed = std::hash::sha256(bytes_to_hash); + let hash_in_a_field = field_from_bytes_32_trunc(sha256_hashed); + + hash_in_a_field +} + +pub trait ToField { + fn to_field(self) -> Field; +} + +pub struct EthAddress { + inner: Field, +} + +impl ToField for EthAddress { + fn to_field(self) -> Field { + self.inner + } +} + +pub struct AztecAddress { + pub inner: Field, +} + +impl ToField for AztecAddress { + fn to_field(self) -> Field { + self.inner + } +} diff --git a/noir/noir-repo/tooling/nargo_fmt/build.rs b/noir/noir-repo/tooling/nargo_fmt/build.rs index 47bb375f7d1..bd2db5f5b18 100644 --- a/noir/noir-repo/tooling/nargo_fmt/build.rs +++ b/noir/noir-repo/tooling/nargo_fmt/build.rs @@ -47,7 +47,10 @@ fn generate_formatter_tests(test_file: &mut File, test_data_dir: &Path) { .join("\n"); let output_source_path = outputs_dir.join(file_name).display().to_string(); - let output_source = std::fs::read_to_string(output_source_path.clone()).unwrap(); + let output_source = + std::fs::read_to_string(output_source_path.clone()).unwrap_or_else(|_| { + panic!("expected output source at {:?} was not found", &output_source_path) + }); write!( test_file, diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter/trait_impl.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter/trait_impl.rs index 73d9a61b3d4..b31da8a4101 100644 --- a/noir/noir-repo/tooling/nargo_fmt/src/formatter/trait_impl.rs +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter/trait_impl.rs @@ -10,6 +10,11 @@ use super::Formatter; impl<'a> Formatter<'a> { pub(super) fn format_trait_impl(&mut self, trait_impl: NoirTraitImpl) { + // skip synthetic trait impl's, e.g. generated from trait aliases + if trait_impl.is_synthetic { + return; + } + let has_where_clause = !trait_impl.where_clause.is_empty(); self.write_indentation(); diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter/traits.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter/traits.rs index 9a6b84c6537..1f192be471e 100644 --- a/noir/noir-repo/tooling/nargo_fmt/src/formatter/traits.rs +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter/traits.rs @@ -15,9 +15,18 @@ impl<'a> Formatter<'a> { self.write_identifier(noir_trait.name); self.format_generics(noir_trait.generics); + if noir_trait.is_alias { + self.write_space(); + self.write_token(Token::Assign); + } + if !noir_trait.bounds.is_empty() { self.skip_comments_and_whitespace(); - self.write_token(Token::Colon); + + if !noir_trait.is_alias { + self.write_token(Token::Colon); + } + self.write_space(); for (index, trait_bound) in noir_trait.bounds.into_iter().enumerate() { @@ -34,6 +43,12 @@ impl<'a> Formatter<'a> { self.format_where_clause(noir_trait.where_clause, true); } + // aliases have ';' in lieu of '{ items }' + if noir_trait.is_alias { + self.write_semicolon(); + return; + } + self.write_space(); self.write_left_brace(); if noir_trait.items.is_empty() { diff --git a/noir/noir-repo/tooling/nargo_fmt/tests/expected/trait_alias.nr b/noir/noir-repo/tooling/nargo_fmt/tests/expected/trait_alias.nr new file mode 100644 index 00000000000..926f3160279 --- /dev/null +++ b/noir/noir-repo/tooling/nargo_fmt/tests/expected/trait_alias.nr @@ -0,0 +1,86 @@ +trait Foo { + fn foo(self) -> Self; +} + +trait Baz = Foo; + +impl Foo for Field { + fn foo(self) -> Self { + self + } +} + +fn baz(x: T) -> T +where + T: Baz, +{ + x.foo() +} + +pub trait Foo_2 { + fn foo_2(self) -> Self; +} + +pub trait Bar_2 { + fn bar_2(self) -> Self; +} + +pub trait Baz_2 = Foo_2 + Bar_2; + +fn baz_2(x: T) -> T +where + T: Baz_2, +{ + x.foo_2().bar_2() +} + +impl Foo_2 for Field { + fn foo_2(self) -> Self { + self + 1 + } +} + +impl Bar_2 for Field { + fn bar_2(self) -> Self { + self + 2 + } +} + +trait Foo_3 { + fn foo_3(self) -> Self; +} + +trait Bar_3 { + fn bar_3(self) -> T; +} + +trait Baz_3 = Foo_3 + Bar_3; + +fn baz_3(x: T) -> U +where + T: Baz_3, +{ + x.foo_3().bar_3() +} + +impl Foo_3 for Field { + fn foo_3(self) -> Self { + self + 1 + } +} + +impl Bar_3 for Field { + fn bar_3(self) -> bool { + true + } +} + +fn main() { + let x: Field = 0; + let _ = baz(x); + + assert(0.foo_2().bar_2() == baz_2(0)); + + assert(0.foo_3().bar_3() == baz_3(0)); +} + diff --git a/noir/noir-repo/tooling/nargo_fmt/tests/input/trait_alias.nr b/noir/noir-repo/tooling/nargo_fmt/tests/input/trait_alias.nr new file mode 100644 index 00000000000..53ae756795b --- /dev/null +++ b/noir/noir-repo/tooling/nargo_fmt/tests/input/trait_alias.nr @@ -0,0 +1,78 @@ +trait Foo { + fn foo(self) -> Self; +} + +trait Baz = Foo; + +impl Foo for Field { + fn foo(self) -> Self { self } +} + +fn baz(x: T) -> T where T: Baz { + x.foo() +} + + +pub trait Foo_2 { + fn foo_2(self) -> Self; +} + +pub trait Bar_2 { + fn bar_2(self) -> Self; +} + +pub trait Baz_2 = Foo_2 + Bar_2; + +fn baz_2(x: T) -> T where T: Baz_2 { + x.foo_2().bar_2() +} + +impl Foo_2 for Field { + fn foo_2(self) -> Self { + self + 1 + } +} + +impl Bar_2 for Field { + fn bar_2(self) -> Self { + self + 2 + } +} + + +trait Foo_3 { + fn foo_3(self) -> Self; +} + +trait Bar_3 { + fn bar_3(self) -> T; +} + +trait Baz_3 = Foo_3 + Bar_3; + +fn baz_3(x: T) -> U where T: Baz_3 { + x.foo_3().bar_3() +} + +impl Foo_3 for Field { + fn foo_3(self) -> Self { + self + 1 + } +} + +impl Bar_3 for Field { + fn bar_3(self) -> bool { + true + } +} + + +fn main() { + let x: Field = 0; + let _ = baz(x); + + assert(0.foo_2().bar_2() == baz_2(0)); + + assert(0.foo_3().bar_3() == baz_3(0)); +} +