From 2bc0f6760c09a07cf66a8edb5420363a1032742b Mon Sep 17 00:00:00 2001 From: jfecher Date: Mon, 13 Jun 2022 08:54:32 -0400 Subject: [PATCH] Merge jf/ssa (#223) * Revert "Revert "Merge 'jf/ssa' to master (#217)" (#222)" This reverts commit 7d40c570859838c9e318458db95d044a7d201d47. * Add Opcode enum * Move passing tests to test_data directory * Remove extra printlns * Fix Gate Display impl * Fix Store instructions being deleted * Some code review changes * cargo fmt * Address more PR comments * Change Result instruction to get only 1 value * Use inline_map for result instructions * Port over new_cloned_instruction fix, move more passing examples to test directory * Re-disable ssa by default * Change replacement to Mark enum * Add get_current_value call * Code review changes * Don't propagate twice * Remove comment * Use Mark::ReplaceWith for function calls * Cargo fmt * Fix flipped jmp instructions * Remove assert * Minor renamings * Re-add assert, remove comment * Fix truncate behavior * Formatting changes * Fix frontend errors for 6_array * Fix 6_arrays * cargo fmt --- .rustfmt.toml | 2 +- crates/acir/src/circuit/gate.rs | 45 +- crates/nargo/src/cli/mod.rs | 14 +- crates/nargo/src/cli/prove_cmd.rs | 21 +- crates/nargo/src/cli/verify_cmd.rs | 10 +- crates/nargo/tests/prove_and_verify.rs | 3 +- .../nargo/tests/test_data/5_over/src/main.nr | 2 +- .../nargo/tests/test_data/6}/Nargo.toml | 0 .../nargo/tests/test_data}/6/Prover.toml | 0 .../nargo/tests/test_data}/6/Verifier.toml | 0 .../nargo/tests/test_data}/6/src/main.nr | 0 .../nargo/tests/test_data/6_array/src/main.nr | 9 +- .../nargo/tests/test_data/7}/Nargo.toml | 0 .../nargo/tests/test_data}/7/Prover.toml | 0 .../nargo/tests/test_data/7}/Verifier.toml | 0 .../nargo/tests/test_data}/7/src/main.nr | 0 .../tests/test_data}/assign_ex/Nargo.toml | 0 .../tests/test_data}/assign_ex/Prover.toml | 0 .../tests/test_data/assign_ex}/Verifier.toml | 0 .../tests/test_data}/assign_ex/src/main.nr | 0 .../tests/test_data}/bool_not/Nargo.toml | 0 .../tests/test_data}/bool_not/Prover.toml | 0 .../tests/test_data/bool_not}/Verifier.toml | 0 .../tests/test_data}/bool_not/src/main.nr | 0 .../nargo/tests/test_data}/bool_or/Nargo.toml | 0 .../tests/test_data}/bool_or/Prover.toml | 0 .../tests/test_data/bool_or}/Verifier.toml | 0 .../tests/test_data}/bool_or/src/main.nr | 0 .../test_data}/pedersen_check/Nargo.toml | 0 .../test_data}/pedersen_check/Prover.toml | 0 .../test_data/pedersen_check}/Verifier.toml | 0 .../test_data}/pedersen_check/src/main.nr | 0 .../nargo/tests/test_data}/pred_eq/Nargo.toml | 0 .../tests/test_data}/pred_eq/Prover.toml | 0 .../tests/test_data/pred_eq}/Verifier.toml | 0 .../tests/test_data}/pred_eq/src/main.nr | 0 .../nargo/tests/test_data/schnorr}/Nargo.toml | 0 .../tests/test_data/schnorr}/Prover.toml | 0 .../tests/test_data/schnorr}/Verifier.toml | 0 .../tests/test_data/schnorr}/src/main.nr | 0 .../nargo/tests/test_data/sha256}/Nargo.toml | 0 .../nargo/tests/test_data/sha256}/Prover.toml | 0 .../tests/test_data/sha256}/Verifier.toml | 0 .../nargo/tests/test_data/sha256}/src/main.nr | 0 .../nargo/tests/test_data}/tuples/Nargo.toml | 0 .../nargo/tests/test_data}/tuples/Prover.toml | 0 .../tests/test_data}/tuples/Verifier.toml | 0 .../nargo/tests/test_data}/tuples/src/main.nr | 0 crates/noir_field/src/generic_ark.rs | 9 + crates/noirc_driver/src/lib.rs | 4 +- crates/noirc_evaluator/Cargo.toml | 1 - crates/noirc_evaluator/src/lib.rs | 13 +- crates/noirc_evaluator/src/ssa/acir_gen.rs | 366 +++-- crates/noirc_evaluator/src/ssa/block.rs | 30 +- crates/noirc_evaluator/src/ssa/code_gen.rs | 201 ++- crates/noirc_evaluator/src/ssa/context.rs | 259 ++-- crates/noirc_evaluator/src/ssa/flatten.rs | 480 ++++--- crates/noirc_evaluator/src/ssa/function.rs | 39 +- crates/noirc_evaluator/src/ssa/integer.rs | 590 ++++---- crates/noirc_evaluator/src/ssa/mem.rs | 61 +- crates/noirc_evaluator/src/ssa/node.rs | 1181 +++++++++-------- crates/noirc_evaluator/src/ssa/optim.rs | 620 ++++----- crates/noirc_evaluator/src/ssa/ssa_form.rs | 22 +- crates/noirc_frontend/src/hir_def/expr.rs | 1 - crates/noirc_frontend/src/parser/parser.rs | 1 - examples/assign_ex/proofs/p.proof | 1 - 66 files changed, 1942 insertions(+), 2043 deletions(-) rename {examples/10 => crates/nargo/tests/test_data/6}/Nargo.toml (100%) rename {examples => crates/nargo/tests/test_data}/6/Prover.toml (100%) rename {examples => crates/nargo/tests/test_data}/6/Verifier.toml (100%) rename {examples => crates/nargo/tests/test_data}/6/src/main.nr (100%) rename {examples/5 => crates/nargo/tests/test_data/7}/Nargo.toml (100%) rename {examples => crates/nargo/tests/test_data}/7/Prover.toml (100%) rename {examples/10 => crates/nargo/tests/test_data/7}/Verifier.toml (100%) rename {examples => crates/nargo/tests/test_data}/7/src/main.nr (100%) rename {examples => crates/nargo/tests/test_data}/assign_ex/Nargo.toml (100%) rename {examples => crates/nargo/tests/test_data}/assign_ex/Prover.toml (100%) rename {examples/5 => crates/nargo/tests/test_data/assign_ex}/Verifier.toml (100%) rename {examples => crates/nargo/tests/test_data}/assign_ex/src/main.nr (100%) rename {examples => crates/nargo/tests/test_data}/bool_not/Nargo.toml (100%) rename {examples => crates/nargo/tests/test_data}/bool_not/Prover.toml (100%) rename {examples/7 => crates/nargo/tests/test_data/bool_not}/Verifier.toml (100%) rename {examples => crates/nargo/tests/test_data}/bool_not/src/main.nr (100%) rename {examples => crates/nargo/tests/test_data}/bool_or/Nargo.toml (100%) rename {examples => crates/nargo/tests/test_data}/bool_or/Prover.toml (100%) rename {examples/assign_ex => crates/nargo/tests/test_data/bool_or}/Verifier.toml (100%) rename {examples => crates/nargo/tests/test_data}/bool_or/src/main.nr (100%) rename {examples => crates/nargo/tests/test_data}/pedersen_check/Nargo.toml (100%) rename {examples => crates/nargo/tests/test_data}/pedersen_check/Prover.toml (100%) rename {examples/bool_not => crates/nargo/tests/test_data/pedersen_check}/Verifier.toml (100%) rename {examples => crates/nargo/tests/test_data}/pedersen_check/src/main.nr (100%) rename {examples => crates/nargo/tests/test_data}/pred_eq/Nargo.toml (100%) rename {examples => crates/nargo/tests/test_data}/pred_eq/Prover.toml (100%) rename {examples/bool_or => crates/nargo/tests/test_data/pred_eq}/Verifier.toml (100%) rename {examples => crates/nargo/tests/test_data}/pred_eq/src/main.nr (100%) rename {examples/6 => crates/nargo/tests/test_data/schnorr}/Nargo.toml (100%) rename {examples/10 => crates/nargo/tests/test_data/schnorr}/Prover.toml (100%) rename {examples/pedersen_check => crates/nargo/tests/test_data/schnorr}/Verifier.toml (100%) rename {examples/10 => crates/nargo/tests/test_data/schnorr}/src/main.nr (100%) rename {examples/7 => crates/nargo/tests/test_data/sha256}/Nargo.toml (100%) rename {examples/5 => crates/nargo/tests/test_data/sha256}/Prover.toml (100%) rename {examples/pred_eq => crates/nargo/tests/test_data/sha256}/Verifier.toml (100%) rename {examples/5 => crates/nargo/tests/test_data/sha256}/src/main.nr (100%) rename {examples => crates/nargo/tests/test_data}/tuples/Nargo.toml (100%) rename {examples => crates/nargo/tests/test_data}/tuples/Prover.toml (100%) rename {examples => crates/nargo/tests/test_data}/tuples/Verifier.toml (100%) rename {examples => crates/nargo/tests/test_data}/tuples/src/main.nr (100%) delete mode 100644 examples/assign_ex/proofs/p.proof diff --git a/.rustfmt.toml b/.rustfmt.toml index cb009953a74..c13d3e328d4 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,2 +1,2 @@ edition = "2018" -use_small_heuristics="Max" +use_small_heuristics = "Max" diff --git a/crates/acir/src/circuit/gate.rs b/crates/acir/src/circuit/gate.rs index 7bf1498fa51..ca8a50d61f2 100644 --- a/crates/acir/src/circuit/gate.rs +++ b/crates/acir/src/circuit/gate.rs @@ -44,69 +44,64 @@ impl Gate { impl std::fmt::Debug for Gate { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut result = String::new(); match self { Gate::Arithmetic(a) => { for i in &a.mul_terms { - result += - &format!("{:?}x{}*x{} + ", i.0, i.1.witness_index(), i.2.witness_index()); + write!(f, "{:?}x{}*x{} + ", i.0, i.1.witness_index(), i.2.witness_index())?; } for i in &a.linear_combinations { - result += &format!("{:?}x{} + ", i.0, i.1.witness_index()); + write!(f, "{:?}x{} + ", i.0, i.1.witness_index())?; } - result += &format!("{:?} = 0", a.q_c); + write!(f, "{:?} = 0", a.q_c) } Gate::Range(w, s) => { - result = format!("x{} is {} bits", w.witness_index(), s); + write!(f, "x{} is {} bits", w.witness_index(), s) } Gate::Directive(Directive::Invert { x, result: r }) => { - result = format!("x{}=1/x{}, or 0", r.witness_index(), x.witness_index()); + write!(f, "x{}=1/x{}, or 0", r.witness_index(), x.witness_index()) } Gate::Directive(Directive::Truncate { a, b, c: _c, bit_size }) => { - result = format!( + write!( + f, "Truncate: x{} is x{} truncated to {} bits", b.witness_index(), a.witness_index(), bit_size - ); + ) } Gate::Directive(Directive::Quotient { a, b, q, r }) => { - result = format!( + write!( + f, "Euclidian division: x{} = x{}*x{} + x{}", a.witness_index(), q.witness_index(), b.witness_index(), r.witness_index() - ); + ) } Gate::Directive(Directive::Oddrange { a, b, r, bit_size }) => { - result = format!( + write!( + f, "Oddrange: x{} = x{}*2^{} + x{}", a.witness_index(), b.witness_index(), bit_size, r.witness_index() - ); - } - Gate::And(g) => { - dbg!(&g); - } - Gate::Xor(g) => { - dbg!(&g); - } - Gate::GadgetCall(g) => { - dbg!(&g); + ) } + Gate::And(g) => write!(f, "{:?}", g), + Gate::Xor(g) => write!(f, "{:?}", g), + Gate::GadgetCall(g) => write!(f, "{:?}", g), Gate::Directive(Directive::Split { a, b, bit_size: _ }) => { - result = format!( + write!( + f, "Split: x{} into x{}...x{}", a.witness_index(), b.first().unwrap().witness_index(), b.last().unwrap().witness_index(), - ); + ) } } - write!(f, "{}", result) } } diff --git a/crates/nargo/src/cli/mod.rs b/crates/nargo/src/cli/mod.rs index 4086efa2030..0a061b0f3a5 100644 --- a/crates/nargo/src/cli/mod.rs +++ b/crates/nargo/src/cli/mod.rs @@ -48,7 +48,11 @@ pub fn start_cli() { App::new("prove") .about("Create proof for this program") .arg(Arg::with_name("proof_name").help("The name of the proof").required(true)) - .arg(Arg::with_name("interactive").help("pause execution").required(false)), + .arg( + Arg::with_name("show-ssa") + .long("show-ssa") + .help("Emit debug information for the intermediate SSA IR"), + ), ) .get_matches(); @@ -88,9 +92,11 @@ fn write_to_file(bytes: &[u8], path: &Path) -> String { } // helper function which tests noir programs by trying to generate a proof and verify it -pub fn prove_and_verify(proof_name: &str, prg_dir: &Path) -> bool { +pub fn prove_and_verify(proof_name: &str, prg_dir: &Path, show_ssa: bool) -> bool { let tmp_dir = TempDir::new("p_and_v_tests").unwrap(); + println!("prove_with_path(_, {})", show_ssa); let proof_path = - prove_cmd::prove_with_path(proof_name, prg_dir, &tmp_dir.into_path(), false).unwrap(); - verify_cmd::verify_with_path(prg_dir, &proof_path).unwrap() + prove_cmd::prove_with_path(proof_name, prg_dir, &tmp_dir.into_path(), show_ssa).unwrap(); + + verify_cmd::verify_with_path(prg_dir, &proof_path, show_ssa).unwrap() } diff --git a/crates/nargo/src/cli/prove_cmd.rs b/crates/nargo/src/cli/prove_cmd.rs index aafc577b287..d608b229e62 100644 --- a/crates/nargo/src/cli/prove_cmd.rs +++ b/crates/nargo/src/cli/prove_cmd.rs @@ -13,26 +13,21 @@ use crate::{errors::CliError, resolver::Resolver}; use super::{create_dir, write_to_file, PROOFS_DIR, PROOF_EXT, PROVER_INPUT_FILE}; pub(crate) fn run(args: ArgMatches) -> Result<(), CliError> { - let proof_name = args.subcommand_matches("prove").unwrap().value_of("proof_name").unwrap(); - let interactive = args.subcommand_matches("prove").unwrap().value_of("interactive"); - let mut is_interactive = false; - if let Some(int) = interactive { - if int == "i" { - is_interactive = true; - } - } - prove(proof_name, is_interactive) + let args = args.subcommand_matches("prove").unwrap(); + let proof_name = args.value_of("proof_name").unwrap(); + let show_ssa = args.is_present("show-ssa"); + prove(proof_name, show_ssa) } /// In Barretenberg, the proof system adds a zero witness in the first index, /// So when we add witness values, their index start from 1. const WITNESS_OFFSET: u32 = 1; -fn prove(proof_name: &str, interactive: bool) -> Result<(), CliError> { +fn prove(proof_name: &str, show_ssa: bool) -> Result<(), CliError> { let curr_dir = std::env::current_dir().unwrap(); let mut proof_path = PathBuf::new(); proof_path.push(PROOFS_DIR); - let result = prove_with_path(proof_name, curr_dir, proof_path, interactive); + let result = prove_with_path(proof_name, curr_dir, proof_path, show_ssa); match result { Ok(_) => Ok(()), Err(e) => Err(e), @@ -89,11 +84,11 @@ pub fn prove_with_path>( proof_name: &str, program_dir: P, proof_dir: P, - interactive: bool, + show_ssa: bool, ) -> Result { let driver = Resolver::resolve_root_config(program_dir.as_ref())?; let backend = crate::backends::ConcreteBackend; - let compiled_program = driver.into_compiled_program(backend.np_language(), interactive); + let compiled_program = driver.into_compiled_program(backend.np_language(), show_ssa); // Parse the initial witness values let witness_map = noirc_abi::input_parser::Format::Toml.parse(program_dir, PROVER_INPUT_FILE); diff --git a/crates/nargo/src/cli/verify_cmd.rs b/crates/nargo/src/cli/verify_cmd.rs index 9b0afd5071f..753c275512b 100644 --- a/crates/nargo/src/cli/verify_cmd.rs +++ b/crates/nargo/src/cli/verify_cmd.rs @@ -31,7 +31,7 @@ fn verify(proof_name: &str) -> Result { proof_path.push(PROOFS_DIR); proof_path.push(Path::new(proof_name)); proof_path.set_extension(PROOF_EXT); - verify_with_path(&curr_dir, &proof_path) + verify_with_path(&curr_dir, &proof_path, false) } fn process_abi_with_verifier_input( @@ -77,11 +77,15 @@ pub fn add_dummy_setpub_arr(abi: &mut Abi) { abi.parameters.push((RESERVED_PUBLIC_ARR.into(), dummy_arr)); } -pub fn verify_with_path>(program_dir: P, proof_path: P) -> Result { +pub fn verify_with_path>( + program_dir: P, + proof_path: P, + show_ssa: bool, +) -> Result { let driver = Resolver::resolve_root_config(program_dir.as_ref())?; let backend = crate::backends::ConcreteBackend; - let compiled_program = driver.into_compiled_program(backend.np_language(), false); + let compiled_program = driver.into_compiled_program(backend.np_language(), show_ssa); let mut public_abi = compiled_program.abi.clone().unwrap().public_abi(); add_dummy_setpub_arr(&mut public_abi); diff --git a/crates/nargo/tests/prove_and_verify.rs b/crates/nargo/tests/prove_and_verify.rs index c2a104c1c97..6a46e831ce5 100644 --- a/crates/nargo/tests/prove_and_verify.rs +++ b/crates/nargo/tests/prove_and_verify.rs @@ -46,7 +46,8 @@ mod tests { match test_name { Ok(str) => { if c.path().is_dir() && !conf_data["exclude"].contains(&str) { - let r = nargo::cli::prove_and_verify("pp", &c.path()); + println!("prove_and_verify(_, true)"); + let r = nargo::cli::prove_and_verify("pp", &c.path(), true); if conf_data["fail"].contains(&str) { assert!(!r, "{:?} should not succeed", c.file_name()); } else { diff --git a/crates/nargo/tests/test_data/5_over/src/main.nr b/crates/nargo/tests/test_data/5_over/src/main.nr index 3fa9fa9c304..7f71dc8d8e2 100644 --- a/crates/nargo/tests/test_data/5_over/src/main.nr +++ b/crates/nargo/tests/test_data/5_over/src/main.nr @@ -2,5 +2,5 @@ fn main(mut x : u32, y : u32) { x = x * x; - constrain (y) == x; + constrain y == x; } diff --git a/examples/10/Nargo.toml b/crates/nargo/tests/test_data/6/Nargo.toml similarity index 100% rename from examples/10/Nargo.toml rename to crates/nargo/tests/test_data/6/Nargo.toml diff --git a/examples/6/Prover.toml b/crates/nargo/tests/test_data/6/Prover.toml similarity index 100% rename from examples/6/Prover.toml rename to crates/nargo/tests/test_data/6/Prover.toml diff --git a/examples/6/Verifier.toml b/crates/nargo/tests/test_data/6/Verifier.toml similarity index 100% rename from examples/6/Verifier.toml rename to crates/nargo/tests/test_data/6/Verifier.toml diff --git a/examples/6/src/main.nr b/crates/nargo/tests/test_data/6/src/main.nr similarity index 100% rename from examples/6/src/main.nr rename to crates/nargo/tests/test_data/6/src/main.nr diff --git a/crates/nargo/tests/test_data/6_array/src/main.nr b/crates/nargo/tests/test_data/6_array/src/main.nr index 66e8f6626d8..2ee4a30caa0 100644 --- a/crates/nargo/tests/test_data/6_array/src/main.nr +++ b/crates/nargo/tests/test_data/6_array/src/main.nr @@ -1,13 +1,14 @@ //Basic tests for arrays -fn main(x : [5]u32 , y : [5]u32 , z : u32, t : u32) { - let c = 2301 as u32; +fn main(x: [5]u32, y: [5]u32, mut z: u32, t: u32) { + let mut c = (z-z+2301) as u32; + //t= t+x[0]-x[0]; z=y[4]; //Test 1: for i in 0..5 { c = z*z*y[i]; - z = z - c; + z = z - c; }; constrain (z==0); //y[4]=0, so c and z are always 0 @@ -41,4 +42,4 @@ fn main(x : [5]u32 , y : [5]u32 , z : u32, t : u32) { }; }; constrain (z ==11539); -} \ No newline at end of file +} diff --git a/examples/5/Nargo.toml b/crates/nargo/tests/test_data/7/Nargo.toml similarity index 100% rename from examples/5/Nargo.toml rename to crates/nargo/tests/test_data/7/Nargo.toml diff --git a/examples/7/Prover.toml b/crates/nargo/tests/test_data/7/Prover.toml similarity index 100% rename from examples/7/Prover.toml rename to crates/nargo/tests/test_data/7/Prover.toml diff --git a/examples/10/Verifier.toml b/crates/nargo/tests/test_data/7/Verifier.toml similarity index 100% rename from examples/10/Verifier.toml rename to crates/nargo/tests/test_data/7/Verifier.toml diff --git a/examples/7/src/main.nr b/crates/nargo/tests/test_data/7/src/main.nr similarity index 100% rename from examples/7/src/main.nr rename to crates/nargo/tests/test_data/7/src/main.nr diff --git a/examples/assign_ex/Nargo.toml b/crates/nargo/tests/test_data/assign_ex/Nargo.toml similarity index 100% rename from examples/assign_ex/Nargo.toml rename to crates/nargo/tests/test_data/assign_ex/Nargo.toml diff --git a/examples/assign_ex/Prover.toml b/crates/nargo/tests/test_data/assign_ex/Prover.toml similarity index 100% rename from examples/assign_ex/Prover.toml rename to crates/nargo/tests/test_data/assign_ex/Prover.toml diff --git a/examples/5/Verifier.toml b/crates/nargo/tests/test_data/assign_ex/Verifier.toml similarity index 100% rename from examples/5/Verifier.toml rename to crates/nargo/tests/test_data/assign_ex/Verifier.toml diff --git a/examples/assign_ex/src/main.nr b/crates/nargo/tests/test_data/assign_ex/src/main.nr similarity index 100% rename from examples/assign_ex/src/main.nr rename to crates/nargo/tests/test_data/assign_ex/src/main.nr diff --git a/examples/bool_not/Nargo.toml b/crates/nargo/tests/test_data/bool_not/Nargo.toml similarity index 100% rename from examples/bool_not/Nargo.toml rename to crates/nargo/tests/test_data/bool_not/Nargo.toml diff --git a/examples/bool_not/Prover.toml b/crates/nargo/tests/test_data/bool_not/Prover.toml similarity index 100% rename from examples/bool_not/Prover.toml rename to crates/nargo/tests/test_data/bool_not/Prover.toml diff --git a/examples/7/Verifier.toml b/crates/nargo/tests/test_data/bool_not/Verifier.toml similarity index 100% rename from examples/7/Verifier.toml rename to crates/nargo/tests/test_data/bool_not/Verifier.toml diff --git a/examples/bool_not/src/main.nr b/crates/nargo/tests/test_data/bool_not/src/main.nr similarity index 100% rename from examples/bool_not/src/main.nr rename to crates/nargo/tests/test_data/bool_not/src/main.nr diff --git a/examples/bool_or/Nargo.toml b/crates/nargo/tests/test_data/bool_or/Nargo.toml similarity index 100% rename from examples/bool_or/Nargo.toml rename to crates/nargo/tests/test_data/bool_or/Nargo.toml diff --git a/examples/bool_or/Prover.toml b/crates/nargo/tests/test_data/bool_or/Prover.toml similarity index 100% rename from examples/bool_or/Prover.toml rename to crates/nargo/tests/test_data/bool_or/Prover.toml diff --git a/examples/assign_ex/Verifier.toml b/crates/nargo/tests/test_data/bool_or/Verifier.toml similarity index 100% rename from examples/assign_ex/Verifier.toml rename to crates/nargo/tests/test_data/bool_or/Verifier.toml diff --git a/examples/bool_or/src/main.nr b/crates/nargo/tests/test_data/bool_or/src/main.nr similarity index 100% rename from examples/bool_or/src/main.nr rename to crates/nargo/tests/test_data/bool_or/src/main.nr diff --git a/examples/pedersen_check/Nargo.toml b/crates/nargo/tests/test_data/pedersen_check/Nargo.toml similarity index 100% rename from examples/pedersen_check/Nargo.toml rename to crates/nargo/tests/test_data/pedersen_check/Nargo.toml diff --git a/examples/pedersen_check/Prover.toml b/crates/nargo/tests/test_data/pedersen_check/Prover.toml similarity index 100% rename from examples/pedersen_check/Prover.toml rename to crates/nargo/tests/test_data/pedersen_check/Prover.toml diff --git a/examples/bool_not/Verifier.toml b/crates/nargo/tests/test_data/pedersen_check/Verifier.toml similarity index 100% rename from examples/bool_not/Verifier.toml rename to crates/nargo/tests/test_data/pedersen_check/Verifier.toml diff --git a/examples/pedersen_check/src/main.nr b/crates/nargo/tests/test_data/pedersen_check/src/main.nr similarity index 100% rename from examples/pedersen_check/src/main.nr rename to crates/nargo/tests/test_data/pedersen_check/src/main.nr diff --git a/examples/pred_eq/Nargo.toml b/crates/nargo/tests/test_data/pred_eq/Nargo.toml similarity index 100% rename from examples/pred_eq/Nargo.toml rename to crates/nargo/tests/test_data/pred_eq/Nargo.toml diff --git a/examples/pred_eq/Prover.toml b/crates/nargo/tests/test_data/pred_eq/Prover.toml similarity index 100% rename from examples/pred_eq/Prover.toml rename to crates/nargo/tests/test_data/pred_eq/Prover.toml diff --git a/examples/bool_or/Verifier.toml b/crates/nargo/tests/test_data/pred_eq/Verifier.toml similarity index 100% rename from examples/bool_or/Verifier.toml rename to crates/nargo/tests/test_data/pred_eq/Verifier.toml diff --git a/examples/pred_eq/src/main.nr b/crates/nargo/tests/test_data/pred_eq/src/main.nr similarity index 100% rename from examples/pred_eq/src/main.nr rename to crates/nargo/tests/test_data/pred_eq/src/main.nr diff --git a/examples/6/Nargo.toml b/crates/nargo/tests/test_data/schnorr/Nargo.toml similarity index 100% rename from examples/6/Nargo.toml rename to crates/nargo/tests/test_data/schnorr/Nargo.toml diff --git a/examples/10/Prover.toml b/crates/nargo/tests/test_data/schnorr/Prover.toml similarity index 100% rename from examples/10/Prover.toml rename to crates/nargo/tests/test_data/schnorr/Prover.toml diff --git a/examples/pedersen_check/Verifier.toml b/crates/nargo/tests/test_data/schnorr/Verifier.toml similarity index 100% rename from examples/pedersen_check/Verifier.toml rename to crates/nargo/tests/test_data/schnorr/Verifier.toml diff --git a/examples/10/src/main.nr b/crates/nargo/tests/test_data/schnorr/src/main.nr similarity index 100% rename from examples/10/src/main.nr rename to crates/nargo/tests/test_data/schnorr/src/main.nr diff --git a/examples/7/Nargo.toml b/crates/nargo/tests/test_data/sha256/Nargo.toml similarity index 100% rename from examples/7/Nargo.toml rename to crates/nargo/tests/test_data/sha256/Nargo.toml diff --git a/examples/5/Prover.toml b/crates/nargo/tests/test_data/sha256/Prover.toml similarity index 100% rename from examples/5/Prover.toml rename to crates/nargo/tests/test_data/sha256/Prover.toml diff --git a/examples/pred_eq/Verifier.toml b/crates/nargo/tests/test_data/sha256/Verifier.toml similarity index 100% rename from examples/pred_eq/Verifier.toml rename to crates/nargo/tests/test_data/sha256/Verifier.toml diff --git a/examples/5/src/main.nr b/crates/nargo/tests/test_data/sha256/src/main.nr similarity index 100% rename from examples/5/src/main.nr rename to crates/nargo/tests/test_data/sha256/src/main.nr diff --git a/examples/tuples/Nargo.toml b/crates/nargo/tests/test_data/tuples/Nargo.toml similarity index 100% rename from examples/tuples/Nargo.toml rename to crates/nargo/tests/test_data/tuples/Nargo.toml diff --git a/examples/tuples/Prover.toml b/crates/nargo/tests/test_data/tuples/Prover.toml similarity index 100% rename from examples/tuples/Prover.toml rename to crates/nargo/tests/test_data/tuples/Prover.toml diff --git a/examples/tuples/Verifier.toml b/crates/nargo/tests/test_data/tuples/Verifier.toml similarity index 100% rename from examples/tuples/Verifier.toml rename to crates/nargo/tests/test_data/tuples/Verifier.toml diff --git a/examples/tuples/src/main.nr b/crates/nargo/tests/test_data/tuples/src/main.nr similarity index 100% rename from examples/tuples/src/main.nr rename to crates/nargo/tests/test_data/tuples/src/main.nr diff --git a/crates/noir_field/src/generic_ark.rs b/crates/noir_field/src/generic_ark.rs index c20e5ec5335..2d7ba43504a 100644 --- a/crates/noir_field/src/generic_ark.rs +++ b/crates/noir_field/src/generic_ark.rs @@ -181,6 +181,11 @@ impl FieldElement { let bytes = self.to_bytes(); u128::from_be_bytes(bytes[16..32].try_into().unwrap()) } + + pub fn try_into_u128(self) -> Option { + self.fits_in_u128().then(|| self.to_u128()) + } + /// Computes the inverse or returns zero if the inverse does not exist /// Before using this FieldElement, please ensure that this behaviour is necessary pub fn inverse(&self) -> FieldElement { @@ -188,6 +193,10 @@ impl FieldElement { FieldElement(inv) } + pub fn try_inverse(mut self) -> Option { + self.0.inverse_in_place().map(|f| FieldElement(*f)) + } + // XXX: This method is used while this field element // implementation is not generic. pub fn into_repr(self) -> F { diff --git a/crates/noirc_driver/src/lib.rs b/crates/noirc_driver/src/lib.rs index 75fcf523590..6d722e10fd6 100644 --- a/crates/noirc_driver/src/lib.rs +++ b/crates/noirc_driver/src/lib.rs @@ -139,7 +139,7 @@ impl Driver { pub fn into_compiled_program( mut self, np_language: acvm::Language, - interactive: bool, + show_ssa: bool, ) -> CompiledProgram { self.build(); // First find the local crate @@ -166,7 +166,7 @@ impl Driver { let evaluator = Evaluator::new(main_function, &self.context); // Compile Program - let circuit = match evaluator.compile(np_language, interactive) { + let circuit = match evaluator.compile(np_language, show_ssa) { Ok(circuit) => circuit, Err(err) => { // The FileId here will be the file id of the file with the main file diff --git a/crates/noirc_evaluator/Cargo.toml b/crates/noirc_evaluator/Cargo.toml index 625319a9641..293aee5ac69 100644 --- a/crates/noirc_evaluator/Cargo.toml +++ b/crates/noirc_evaluator/Cargo.toml @@ -16,4 +16,3 @@ lazy_static = "1.4.0" thiserror = "1.0.21" num-bigint = "0.4" num-traits = "0.2.8" -if_debug = "0.1.0" diff --git a/crates/noirc_evaluator/src/lib.rs b/crates/noirc_evaluator/src/lib.rs index acc264fa9bb..bd6b5ae1f58 100644 --- a/crates/noirc_evaluator/src/lib.rs +++ b/crates/noirc_evaluator/src/lib.rs @@ -96,14 +96,14 @@ impl<'a> Evaluator<'a> { pub fn compile( mut self, np_language: Language, - interactive: bool, + enable_logging: bool, ) -> Result { // create a new environment for the main context let mut env = Environment::new(FuncContext::Main); // First evaluate the main function - if interactive { - self.evaluate_main_alt(&mut env, interactive)?; + if enable_logging { + self.evaluate_main_alt(&mut env, enable_logging)?; } else { self.evaluate_main(&mut env)?; } @@ -173,9 +173,6 @@ impl<'a> Evaluator<'a> { HirBinaryOpKind::Shr | HirBinaryOpKind::Shl => Err(RuntimeErrorKind::Unimplemented( "Bit shift operations are not currently implemented.".to_owned(), )), - HirBinaryOpKind::MemberAccess => { - todo!("Member access for structs is unimplemented in the noir backend") - } } .map_err(|kind| kind.add_span(op.span)) } @@ -208,7 +205,7 @@ impl<'a> Evaluator<'a> { pub fn evaluate_main_alt( &mut self, env: &mut Environment, - interactive: bool, + enable_logging: bool, ) -> Result<(), RuntimeError> { let mut igen = IRGenerator::new(self.context); self.parse_abi_alt(env, &mut igen)?; @@ -218,7 +215,7 @@ impl<'a> Evaluator<'a> { ssa::code_gen::evaluate_main(&mut igen, env, main_func_body)?; //Generates ACIR representation: - igen.context.ir_to_acir(self, interactive)?; + igen.context.ir_to_acir(self, enable_logging)?; Ok(()) } diff --git a/crates/noirc_evaluator/src/ssa/acir_gen.rs b/crates/noirc_evaluator/src/ssa/acir_gen.rs index 4ef9b5d559c..31364f21786 100644 --- a/crates/noirc_evaluator/src/ssa/acir_gen.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen.rs @@ -1,10 +1,9 @@ -use super::mem::{MemArray, Memory}; -use super::node::{Instruction, Operation}; +use super::mem::{ArrayId, MemArray, Memory}; +use super::node::{BinaryOp, ConstrainOp, Instruction, ObjectType, Operation}; +use acvm::acir::OPCODE; use acvm::FieldElement; use super::node::NodeId; -use acvm::acir::opcode::{InputSize, OutputSize}; -use acvm::acir::OPCODE; use num_traits::{One, Zero}; use std::cmp::Ordering; @@ -20,13 +19,12 @@ use crate::RuntimeErrorKind; use acvm::acir::circuit::gate::{Directive, GadgetCall, GadgetInput}; use acvm::acir::native_types::{Arithmetic, Linear, Witness}; use num_bigint::BigUint; -use std::convert::TryInto; #[derive(Default)] pub struct Acir { pub arith_cache: HashMap, pub memory_map: HashMap, //maps memory adress to expression - pub memory_witness: HashMap>, //map arrays to their witness...temporary + pub memory_witness: HashMap>, //map arrays to their witness...temporary } #[derive(Default, Clone, Debug)] @@ -36,6 +34,7 @@ pub struct InternalVar { witness: Option, id: Option, } + impl InternalVar { pub fn is_equal(&self, b: &InternalVar) -> bool { (self.id.is_some() && self.id == b.id) @@ -54,7 +53,7 @@ impl InternalVar { None } - pub fn get_or_generate_witness(&mut self, evaluator: &mut Evaluator) -> Witness { + pub fn get_or_generate_witness(&self, evaluator: &mut Evaluator) -> Witness { self.witness.unwrap_or_else(|| generate_witness(self, evaluator)) } } @@ -134,17 +133,90 @@ impl Acir { evaluator: &mut Evaluator, ctx: &SsaContext, ) { - if ins.operator == Operation::Nop { + if ins.operation == Operation::Nop { return; } - let l_c = self.substitute(ins.lhs, evaluator, ctx); - let mut r_c = self.substitute(ins.rhs, evaluator, ctx); - let mut output = match ins.operator { - Operation::Add | Operation::SafeAdd => { + + let mut output = match &ins.operation { + Operation::Binary(binary) => self.evaluate_binary(binary, ins.res_type, evaluator, ctx), + Operation::Not(_) => todo!(), + Operation::Cast(value) => self.substitute(*value, evaluator, ctx), + i @ Operation::Jne(..) + | i @ Operation::Jeq(..) + | i @ Operation::Jmp(_) + | i @ Operation::Phi { .. } + | i @ Operation::Result { .. } => { + unreachable!("Invalid instruction: {:?}", i); + } + Operation::Truncate { value, bit_size, max_bit_size } => { + let value = self.substitute(*value, evaluator, ctx); + evaluate_truncate(value, *bit_size, *max_bit_size, evaluator) + } + Operation::Intrinsic(opcode, args) => { + let v = self.evaluate_opcode(ins.id, *opcode, args, ins.res_type, ctx, evaluator); + InternalVar::from(v) + } + Operation::Call(..) => unreachable!("call instruction should have been inlined"), + Operation::Return(_) => todo!(), //return from main + Operation::Nop => InternalVar::default(), + Operation::Load { array_id, index } => { + //retrieves the value from the map if address is known at compile time: + //address = l_c and should be constant + let index = self.substitute(*index, evaluator, ctx); + if let Some(index) = index.to_const() { + let address = mem::Memory::as_u32(index); + if self.memory_map.contains_key(&address) { + InternalVar::from(self.memory_map[&address].expression.clone()) + } else { + //if not found, then it must be a witness (else it is non-initialised memory) + let mem_array = &ctx.mem[*array_id]; + let index = (address - mem_array.adr) as usize; + if mem_array.values.len() > index { + mem_array.values[index].clone() + } else { + InternalVar::from(self.memory_witness[array_id][index]) + } + } + } else { + todo!("dynamic arrays are not implemented yet"); + } + } + + Operation::Store { array_id: _, index, value } => { + //maps the address to the rhs if address is known at compile time + let index = self.substitute(*index, evaluator, ctx); + let value = self.substitute(*value, evaluator, ctx); + + if let Some(index) = index.to_const() { + let address = mem::Memory::as_u32(index); + self.memory_map.insert(address, value); + //we do not generate constraint, so no output. + InternalVar::default() + } else { + todo!("dynamic arrays are not implemented yet"); + } + } + }; + output.id = Some(ins.id); + self.arith_cache.insert(ins.id, output); + } + + fn evaluate_binary( + &mut self, + binary: &node::Binary, + res_type: ObjectType, + evaluator: &mut Evaluator, + ctx: &SsaContext, + ) -> InternalVar { + let l_c = self.substitute(binary.lhs, evaluator, ctx); + let r_c = self.substitute(binary.rhs, evaluator, ctx); + + match &binary.operator { + BinaryOp::Add | BinaryOp::SafeAdd => { InternalVar::from(add(&l_c.expression, FieldElement::one(), &r_c.expression)) } - Operation::Sub | Operation::SafeSub => { - if ins.res_type == node::ObjectType::NativeField { + BinaryOp::Sub { max_rhs_value } | BinaryOp::SafeSub { max_rhs_value } => { + if res_type == node::ObjectType::NativeField { InternalVar::from(subtract( &l_c.expression, FieldElement::one(), @@ -153,10 +225,10 @@ impl Acir { } else { //we need the type of rhs and its max value, then: //lhs-rhs+k*2^bit_size where k=ceil(max_value/2^bit_size) - let bit_size = ctx[ins.rhs].size_in_bits(); + let bit_size = ctx[binary.rhs].size_in_bits(); let r_big = BigUint::one() << bit_size; - let mut k = &ins.max_value / &r_big; - if &ins.max_value % &r_big != BigUint::zero() { + let mut k = max_rhs_value / &r_big; + if max_rhs_value % &r_big != BigUint::zero() { k = &k + BigUint::one(); } k = &k * r_big; @@ -167,7 +239,7 @@ impl Acir { let mut sub_var = sub_expr.into(); //TODO: uses interval analysis for more precise check if let Some(lhs_const) = l_c.to_const() { - if ins.max_value <= BigUint::from_bytes_be(&lhs_const.to_bytes()) { + if max_rhs_value <= &BigUint::from_bytes_be(&lhs_const.to_bytes()) { sub_var = InternalVar::from(subtract( &l_c.expression, FieldElement::one(), @@ -178,36 +250,36 @@ impl Acir { sub_var } } - Operation::Mul | Operation::SafeMul => { + BinaryOp::Mul | BinaryOp::SafeMul => { InternalVar::from(evaluate_mul(&l_c, &r_c, evaluator)) } - Operation::Udiv => { + BinaryOp::Udiv => { let (q_wit, _) = evaluate_udiv(&l_c, &r_c, evaluator); InternalVar::from(q_wit) } - Operation::Sdiv => InternalVar::from(evaluate_sdiv(&l_c, &r_c, evaluator).0), - Operation::Urem => { + BinaryOp::Sdiv => InternalVar::from(evaluate_sdiv(&l_c, &r_c, evaluator).0), + BinaryOp::Urem => { let (_, r_wit) = evaluate_udiv(&l_c, &r_c, evaluator); InternalVar::from(r_wit) } - Operation::Srem => InternalVar::from(evaluate_sdiv(&l_c, &r_c, evaluator).1), - Operation::Div => InternalVar::from(mul( + BinaryOp::Srem => InternalVar::from(evaluate_sdiv(&l_c, &r_c, evaluator).1), + BinaryOp::Div => InternalVar::from(mul( &l_c.expression, - &from_witness(evaluate_inverse(&mut r_c, evaluator)), + &from_witness(evaluate_inverse(r_c, evaluator)), )), - Operation::Eq => { - InternalVar::from(self.evaluate_eq(ins.lhs, ins.rhs, &l_c, &r_c, ctx, evaluator)) - } - Operation::Ne => { - InternalVar::from(self.evaluate_neq(ins.lhs, ins.rhs, &l_c, &r_c, ctx, evaluator)) + BinaryOp::Eq => InternalVar::from( + self.evaluate_eq(binary.lhs, binary.rhs, &l_c, &r_c, ctx, evaluator), + ), + BinaryOp::Ne => InternalVar::from( + self.evaluate_neq(binary.lhs, binary.rhs, &l_c, &r_c, ctx, evaluator), + ), + BinaryOp::Ult => { + let size = ctx[binary.lhs].size_in_bits(); + evaluate_cmp(&l_c, &r_c, size, false, evaluator).into() } - Operation::Ugt => { - let s = ctx[ins.lhs].size_in_bits(); - evaluate_cmp(&r_c, &l_c, s, false, evaluator).into() - } - Operation::Uge => { - let s = ctx[ins.lhs].size_in_bits(); - let w = evaluate_cmp(&l_c, &r_c, s, false, evaluator); + BinaryOp::Ule => { + let size = ctx[binary.lhs].size_in_bits(); + let w = evaluate_cmp(&r_c, &l_c, size, false, evaluator); Arithmetic { mul_terms: Vec::new(), linear_combinations: vec![(-FieldElement::one(), w)], @@ -215,40 +287,12 @@ impl Acir { } .into() } - Operation::Ult => { - let s = ctx[ins.lhs].size_in_bits(); - evaluate_cmp(&l_c, &r_c, s, false, evaluator).into() - } - Operation::Ule => { - let s = ctx[ins.lhs].size_in_bits(); - let w = evaluate_cmp(&r_c, &l_c, s, false, evaluator); - Arithmetic { - mul_terms: Vec::new(), - linear_combinations: vec![(-FieldElement::one(), w)], - q_c: FieldElement::one(), - } - .into() - } - Operation::Sgt => { - let s = ctx[ins.lhs].size_in_bits(); - evaluate_cmp(&r_c, &l_c, s, true, evaluator).into() - } - Operation::Sge => { - let s = ctx[ins.lhs].size_in_bits(); - let w = evaluate_cmp(&l_c, &r_c, s, true, evaluator); - Arithmetic { - mul_terms: Vec::new(), - linear_combinations: vec![(-FieldElement::one(), w)], - q_c: FieldElement::one(), - } - .into() - } - Operation::Slt => { - let s = ctx[ins.lhs].size_in_bits(); + BinaryOp::Slt => { + let s = ctx[binary.lhs].size_in_bits(); evaluate_cmp(&l_c, &r_c, s, true, evaluator).into() } - Operation::Sle => { - let s = ctx[ins.lhs].size_in_bits(); + BinaryOp::Sle => { + let s = ctx[binary.lhs].size_in_bits(); let w = evaluate_cmp(&r_c, &l_c, s, true, evaluator); Arithmetic { mul_terms: Vec::new(), @@ -257,88 +301,26 @@ impl Acir { } .into() } - Operation::Lt => todo!(), - Operation::Gte | Operation::Lte | Operation::Gt => unreachable!(), - Operation::Shl | Operation::Shr => unreachable!(), - Operation::And => { - InternalVar::from(evaluate_and(l_c, r_c, ins.res_type.bits(), evaluator)) - } - Operation::Not => todo!(), - Operation::Or => { - InternalVar::from(evaluate_or(l_c, r_c, ins.res_type.bits(), evaluator)) - } - Operation::Xor => { - InternalVar::from(evaluate_xor(l_c, r_c, ins.res_type.bits(), evaluator)) - } - Operation::Cast => l_c.clone(), - Operation::Ass - | Operation::Jne - | Operation::Jeq - | Operation::Jmp - | Operation::Phi - | Operation::Res => { - unreachable!("invalid instruction"); + BinaryOp::Lt => todo!(), + BinaryOp::Lte => { + let size = ctx[binary.lhs].size_in_bits(); + // TODO: Should this be signed? + evaluate_cmp(&l_c, &r_c, size, false, evaluator).into() } - Operation::Trunc => { - assert!(is_const(&r_c.expression)); - evaluate_truncate( - l_c, - r_c.expression.q_c.to_u128().try_into().unwrap(), - ins.bit_size, - evaluator, - ) - } - Operation::Intrinsic(opcode) => { - InternalVar::from(self.evaluate_opcode(ins, opcode, ctx, evaluator)) - } - Operation::Call(_) => unreachable!("call instruction should have been inlined"), - Operation::Ret => todo!(), //return from main - Operation::Nop => InternalVar::default(), - Operation::Constrain(op) => match op { - node::ConstrainOp::Eq => { - InternalVar::from(self.equalize(ins.lhs, ins.rhs, &l_c, &r_c, ctx, evaluator)) - } - node::ConstrainOp::Neq => { - InternalVar::from(self.distinct(ins.lhs, ins.rhs, &l_c, &r_c, ctx, evaluator)) - } + BinaryOp::And => InternalVar::from(evaluate_and(l_c, r_c, res_type.bits(), evaluator)), + BinaryOp::Or => InternalVar::from(evaluate_or(l_c, r_c, res_type.bits(), evaluator)), + BinaryOp::Xor => InternalVar::from(evaluate_xor(l_c, r_c, res_type.bits(), evaluator)), + BinaryOp::Constrain(op) => match op { + ConstrainOp::Eq => InternalVar::from( + self.equalize(binary.lhs, binary.rhs, &l_c, &r_c, ctx, evaluator), + ), + ConstrainOp::Neq => InternalVar::from( + self.distinct(binary.lhs, binary.rhs, &l_c, &r_c, ctx, evaluator), + ), }, - Operation::Load(array_idx) => { - //retrieves the value from the map if address is known at compile time: - //address = l_c and should be constant - if let Some(val) = l_c.to_const() { - let address = mem::Memory::as_u32(val); - if self.memory_map.contains_key(&address) { - InternalVar::from(self.memory_map[&address].expression.clone()) - } else { - //if not found, then it must be a witness (else it is non-initialised memory) - let array = &ctx.mem.arrays[array_idx as usize]; - let index = (address - array.adr) as usize; - if array.values.len() > index { - array.values[index].clone() - } else { - InternalVar::from(self.memory_witness[&array_idx][index]) - } - } - } else { - todo!("dynamic arrays are not implemented yet"); - } - } - - Operation::Store(_) => { - //maps the address to the rhs if address is known at compile time - if let Some(val) = r_c.to_const() { - let address = mem::Memory::as_u32(val); - self.memory_map.insert(address, l_c); - dbg!(&self.memory_map); - //we do not generate constraint, so no output. - InternalVar::default() - } else { - todo!("dynamic arrays are not implemented yet"); - } - } - }; - output.id = Some(ins.id); - self.arith_cache.insert(ins.id, output); + BinaryOp::Shl | BinaryOp::Shr => unreachable!(), + i @ BinaryOp::Assign => unreachable!("Invalid Instruction: {:?}", i), + } } pub fn print_circuit(gates: &[Gate]) { @@ -352,7 +334,6 @@ impl Acir { pub fn load_array( &mut self, array: &MemArray, - array_index: u32, create_witness: bool, evaluator: &mut Evaluator, ) -> Vec { @@ -366,7 +347,7 @@ impl Acir { self.memory_map.get_mut(&address).unwrap().witness = Some(w); } self.memory_map[&address].clone() - } else if let Some(memory) = self.memory_witness.get(&array_index) { + } else if let Some(memory) = self.memory_witness.get(&array.id) { let w = memory[i as usize]; w.into() } else { @@ -386,12 +367,11 @@ impl Acir { evaluator: &mut Evaluator, ) -> Arithmetic { if let (Some(a), Some(b)) = (Memory::deref(ctx, lhs), Memory::deref(ctx, rhs)) { - let array_a = &ctx.mem.arrays[a as usize]; - let array_b = &ctx.mem.arrays[b as usize]; + let array_a = &ctx.mem[a]; + let array_b = &ctx.mem[b]; if array_a.len == array_b.len { - let mut x = - InternalVar::from(self.zero_eq_array_sum(array_a, a, array_b, b, evaluator)); + let mut x = InternalVar::from(self.zero_eq_array_sum(array_a, array_b, evaluator)); x.witness = Some(generate_witness(&x, evaluator)); from_witness(evaluate_zero_equality(&x, evaluator)) } else { @@ -441,16 +421,16 @@ impl Acir { evaluator: &mut Evaluator, ) -> Arithmetic { if let (Some(a), Some(b)) = (Memory::deref(ctx, lhs), Memory::deref(ctx, rhs)) { - let array_a = &ctx.mem.arrays[a as usize]; - let array_b = &ctx.mem.arrays[b as usize]; + let array_a = &ctx.mem[a]; + let array_b = &ctx.mem[b]; //If length are different, then the arrays are different if array_a.len == array_b.len { - let sum = self.zero_eq_array_sum(array_a, a, array_b, b, evaluator); - evaluate_inverse(&mut InternalVar::from(sum), evaluator); + let sum = self.zero_eq_array_sum(array_a, array_b, evaluator); + evaluate_inverse(InternalVar::from(sum), evaluator); } } else { let diff = subtract(&l_c.expression, FieldElement::one(), &r_c.expression); - evaluate_inverse(&mut InternalVar::from(diff), evaluator); + evaluate_inverse(InternalVar::from(diff), evaluator); } Arithmetic::default() } @@ -466,8 +446,8 @@ impl Acir { evaluator: &mut Evaluator, ) -> Arithmetic { if let (Some(a), Some(b)) = (Memory::deref(ctx, lhs), Memory::deref(ctx, rhs)) { - let a_values = self.load_array(&ctx.mem.arrays[a as usize], a, false, evaluator); - let b_values = self.load_array(&ctx.mem.arrays[b as usize], b, false, evaluator); + let a_values = self.load_array(&ctx.mem[a], false, evaluator); + let b_values = self.load_array(&ctx.mem[b], false, evaluator); assert!(a_values.len() == b_values.len()); for (a_iter, b_iter) in a_values.into_iter().zip(b_values) { let array_diff = @@ -478,7 +458,7 @@ impl Acir { } else { let output = add(&l_c.expression, FieldElement::from(-1_i128), &r_c.expression); if is_const(&output) { - assert!(output.q_c == FieldElement::zero()); + assert_eq!(output.q_c, FieldElement::zero()); } else { evaluator.gates.push(Gate::Arithmetic(output.clone())); } @@ -491,15 +471,13 @@ impl Acir { fn zero_eq_array_sum( &mut self, a: &MemArray, - a_idx: u32, b: &MemArray, - b_idx: u32, evaluator: &mut Evaluator, ) -> Arithmetic { let mut sum = Arithmetic::default(); - let a_values = self.load_array(a, a_idx, false, evaluator); - let b_values = self.load_array(b, b_idx, false, evaluator); + let a_values = self.load_array(a, false, evaluator); + let b_values = self.load_array(b, false, evaluator); for (a_iter, b_iter) in a_values.into_iter().zip(b_values) { let diff_expr = subtract(&a_iter.expression, FieldElement::one(), &b_iter.expression); @@ -541,7 +519,7 @@ impl Acir { node::NodeObj::Obj(v) => { match l_obj.get_type() { node::ObjectType::Pointer(a) => { - let array = &cfg.mem.arrays[a as usize]; + let array = &cfg.mem[a]; let num_bits = array.element_type.bits(); for i in 0..array.len { let address = array.adr + i; @@ -584,44 +562,32 @@ impl Acir { pub fn evaluate_opcode( &mut self, - ins: &Instruction, + instruction_id: NodeId, opcode: OPCODE, + args: &[NodeId], + res_type: ObjectType, cfg: &SsaContext, evaluator: &mut Evaluator, ) -> Arithmetic { - let inputs; - let outputs; - let signature = opcode.definition(); + if opcode == OPCODE::ToBits { + todo!(); + } + let outputs; match opcode { OPCODE::ToBits => { - let bit_size = cfg.get_as_constant(ins.rhs).unwrap().to_u128() as u32; - let l_c = self.substitute(ins.lhs, evaluator, cfg); + let bit_size = cfg.get_as_constant(args[1]).unwrap().to_u128() as u32; + let l_c = self.substitute(args[0], evaluator, cfg); outputs = split(&l_c, bit_size, evaluator); - if let node::ObjectType::Pointer(a) = ins.res_type { + if let node::ObjectType::Pointer(a) = res_type { self.memory_witness.insert(a, outputs.clone()); //TODO can we avoid the clone? } } _ => { - match (signature.input_size, signature.output_size) { - (InputSize::Variable, OutputSize(y)) => { - inputs = self.prepare_inputs(&ins.get_arguments(), cfg, evaluator); - outputs = self.prepare_outputs(ins.id, y as u32, cfg, evaluator); - } - (InputSize::Fixed(x), OutputSize(y)) => match x { - 1 => { - inputs = self.prepare_inputs(&[ins.lhs], cfg, evaluator); - outputs = self.prepare_outputs(ins.id, y as u32, cfg, evaluator); - } - 2 => { - inputs = self.prepare_inputs(&[ins.lhs, ins.rhs], cfg, evaluator); - outputs = self.prepare_outputs(ins.id, y as u32, cfg, evaluator); - } - _ => { - todo!(); - } - }, - } + let inputs = self.prepare_inputs(args, cfg, evaluator); + let output_count = opcode.definition().output_size.0 as u32; + outputs = self.prepare_outputs(instruction_id, output_count, cfg, evaluator); + let call_gate = GadgetCall { name: opcode, inputs, //witness + bit size @@ -632,10 +598,11 @@ impl Acir { } if outputs.len() == 1 { - return from_witness(outputs[0]); + from_witness(outputs[0]) + } else { + //if there are more than one witness returned, the result is inside ins.res_type as a pointer to an array + Arithmetic::default() } - //if there are more than one witness returned, the result is inside ins.res_type as a pointer to an array - Arithmetic::default() } pub fn prepare_outputs( @@ -856,7 +823,8 @@ pub fn evaluate_truncate( max_bits: u32, evaluator: &mut Evaluator, ) -> InternalVar { - assert!(max_bits > rhs); + assert!(max_bits > rhs, "max_bits = {}, rhs = {}", max_bits, rhs); + //0. Check for constant expression. This can happen through arithmetic simplifications if let Some(a_c) = lhs.to_const() { let mut a_big = BigUint::from_bytes_be(&a_c.to_bytes()); @@ -1014,7 +982,7 @@ pub fn evaluate_zero_equality(x: &InternalVar, evaluator: &mut Evaluator) -> Wit } /// Creates a new witness and constrains it to be the inverse of x -pub fn evaluate_inverse(x: &mut InternalVar, evaluator: &mut Evaluator) -> Witness { +fn evaluate_inverse(x: InternalVar, evaluator: &mut Evaluator) -> Witness { // Create a fresh witness - n.b we could check if x is constant or not let inverse_witness = evaluator.add_witness_to_cs(); let inverse_expr = from_witness(inverse_witness); diff --git a/crates/noirc_evaluator/src/ssa/block.rs b/crates/noirc_evaluator/src/ssa/block.rs index 00f2a98b5a3..ed48965959d 100644 --- a/crates/noirc_evaluator/src/ssa/block.rs +++ b/crates/noirc_evaluator/src/ssa/block.rs @@ -65,27 +65,13 @@ impl BasicBlock { self.kind == BlockType::ForJoin } - pub fn get_result_instruction(&self, call_id: NodeId, ctx: &SsaContext) -> Option { - self.instructions.iter().copied().find(|i| match ctx[*i] { - node::NodeObj::Instr(node::Instruction { - operator: node::Operation::Res, lhs, .. - }) => lhs == call_id, - _ => false, - }) - } - //Create the first block for a CFG pub fn create_cfg(ctx: &mut SsaContext) -> BlockId { let root_block = BasicBlock::new(BlockId::dummy(), BlockType::Normal); let root_block = ctx.insert_block(root_block); let root_id = root_block.id; ctx.current_block = root_id; - ctx.new_instruction( - NodeId::dummy(), - NodeId::dummy(), - node::Operation::Nop, - node::ObjectType::NotAnObject, - ); + ctx.new_instruction(node::Operation::Nop, node::ObjectType::NotAnObject); root_id } } @@ -109,12 +95,7 @@ pub fn new_sealed_block(ctx: &mut SsaContext, kind: BlockType) -> BlockId { let cb = ctx.get_current_block_mut(); cb.left = Some(new_id); ctx.current_block = new_id; - ctx.new_instruction( - NodeId::dummy(), - NodeId::dummy(), - node::Operation::Nop, - node::ObjectType::NotAnObject, - ); + ctx.new_instruction(node::Operation::Nop, node::ObjectType::NotAnObject); new_id } @@ -134,12 +115,7 @@ pub fn new_unsealed_block(ctx: &mut SsaContext, kind: BlockType, left: bool) -> } ctx.current_block = new_idx; - ctx.new_instruction( - NodeId::dummy(), - NodeId::dummy(), - node::Operation::Nop, - node::ObjectType::NotAnObject, - ); + ctx.new_instruction(node::Operation::Nop, node::ObjectType::NotAnObject); new_idx } diff --git a/crates/noirc_evaluator/src/ssa/code_gen.rs b/crates/noirc_evaluator/src/ssa/code_gen.rs index dd21a58c2eb..098c6276fb6 100644 --- a/crates/noirc_evaluator/src/ssa/code_gen.rs +++ b/crates/noirc_evaluator/src/ssa/code_gen.rs @@ -1,6 +1,5 @@ -use super::block::BlockId; use super::context::SsaContext; -use super::node::{ConstrainOp, Instruction, Node, NodeId, Operation, Variable}; +use super::node::{Binary, BinaryOp, ConstrainOp, Node, NodeId, ObjectType, Operation, Variable}; use super::{block, node, ssa_form}; use std::collections::HashMap; @@ -24,6 +23,8 @@ use noirc_frontend::hir_def::{ use noirc_frontend::node_interner::{DefinitionId, ExprId, NodeInterner, StmtId}; use noirc_frontend::util::vecmap; use noirc_frontend::{FunctionKind, Type}; +use num_bigint::BigUint; +use num_traits::Zero; pub struct IRGenerator<'a> { pub context: SsaContext<'a>, @@ -95,13 +96,15 @@ impl<'a> IRGenerator<'a> { witness: Vec, ) { self.context.mem.create_new_array(len as u32, el_type.into(), name); - let array_idx = (self.context.mem.arrays.len() - 1) as usize; - self.context.mem.arrays[array_idx].def = ident_def; - self.context.mem.arrays[array_idx].values = vecmap(witness, |w| w.into()); + + let array_idx = self.context.mem.last_id(); + + self.context.mem[array_idx].def = ident_def; + self.context.mem[array_idx].values = vecmap(witness, |w| w.into()); let pointer = node::Variable { id: NodeId::dummy(), name: name.to_string(), - obj_type: node::ObjectType::Pointer(array_idx as u32), + obj_type: node::ObjectType::Pointer(array_idx), root: None, def: Some(ident_def), witness: None, @@ -152,12 +155,13 @@ impl<'a> IRGenerator<'a> { Object::Array(a) => { let obj_type = o_type.into(); //We should create an array from 'a' witnesses - self.context.mem.create_array_from_object(&a, ident.id, obj_type, &ident_name); - let array_index = (self.context.mem.arrays.len() - 1) as u32; + let array = + self.context.mem.create_array_from_object(&a, ident.id, obj_type, &ident_name); + node::Variable { id: NodeId::dummy(), name: ident_name.clone(), - obj_type: node::ObjectType::Pointer(array_index), + obj_type: ObjectType::Pointer(array.id), root: None, def: Some(ident.id), witness: None, @@ -165,7 +169,7 @@ impl<'a> IRGenerator<'a> { } } _ => { - let obj_type = node::ObjectType::get_type_from_object(&obj); + let obj_type = ObjectType::get_type_from_object(&obj); //new variable - should be in a let statement? The let statement should set the type node::Variable { id: NodeId::dummy(), @@ -197,36 +201,35 @@ impl<'a> IRGenerator<'a> { let rtype = self.context.get_object_type(rhs); match op { HirUnaryOp::Minus => { - Ok(self.context.new_instruction(self.context.zero(), rhs, Operation::Sub, rtype)) + let lhs = self.context.zero(); + let operator = BinaryOp::Sub { max_rhs_value: BigUint::zero() }; + let op = Operation::Binary(node::Binary { operator, lhs, rhs }); + Ok(self.context.new_instruction(op, rtype)) } - HirUnaryOp::Not => Ok(self.context.new_instruction(rhs, rhs, Operation::Not, rtype)), + HirUnaryOp::Not => Ok(self.context.new_instruction(Operation::Not(rhs), rtype)), } } - fn evaluate_infix_expression( - &mut self, - lhs: NodeId, - rhs: NodeId, - op: HirBinaryOp, - ) -> Result { + fn evaluate_infix_expression(&mut self, lhs: NodeId, rhs: NodeId, op: HirBinaryOp) -> NodeId { let ltype = self.context.get_object_type(lhs); //n.b. we do not verify rhs type as it should have been handled by the type checker. - // Get the opcode from the infix operator - let opcode = node::to_operation(op.kind, ltype); - // Get the result type from the opcode - let optype = self.context.get_result_type(opcode, ltype); - if opcode == node::Operation::Ass { - if let Some(lhs_ins) = self.context.try_get_mut_instruction(lhs) { - if let node::Operation::Load(array) = lhs_ins.operator { - //make it a store rhs - lhs_ins.operator = node::Operation::Store(array); - lhs_ins.lhs = rhs; - return Ok(lhs); - } + if let (HirBinaryOpKind::Assign, Some(lhs_ins)) = + (op.kind, self.context.try_get_mut_instruction(lhs)) + { + if let Operation::Load { array_id, index } = lhs_ins.operation { + //make it a store rhs + lhs_ins.operation = Operation::Store { array_id, index, value: rhs }; + return lhs; } } - Ok(self.context.new_instruction(lhs, rhs, opcode, optype)) + + // Get the opcode from the infix operator + let binary = Binary::from_hir(op.kind, ltype, lhs, rhs); + let opcode = Operation::Binary(binary); + + let optype = self.context.get_result_type(&opcode, ltype); + self.context.new_instruction(opcode, optype) } pub fn evaluate_statement( @@ -327,16 +330,12 @@ impl<'a> IRGenerator<'a> { // HirBinaryOpKind::Multiply => binary_op::handle_mul_op(lhs, rhs, self), // HirBinaryOpKind::Divide => binary_op::handle_div_op(lhs, rhs, self), HirBinaryOpKind::NotEqual => Ok(self.context.new_instruction( - lhs, - rhs, - node::Operation::Constrain(ConstrainOp::Neq), - node::ObjectType::NotAnObject, + Operation::binary(BinaryOp::Constrain(ConstrainOp::Neq), lhs, rhs), + ObjectType::NotAnObject, )), HirBinaryOpKind::Equal => Ok(self.context.new_instruction( - lhs, - rhs, - node::Operation::Constrain(ConstrainOp::Eq), - node::ObjectType::NotAnObject, + Operation::binary(BinaryOp::Constrain(ConstrainOp::Eq), lhs, rhs), + ObjectType::NotAnObject, )), HirBinaryOpKind::And => todo!(), // HirBinaryOpKind::Xor => binary_op::handle_xor_op(lhs, rhs, self), @@ -427,7 +426,7 @@ impl<'a> IRGenerator<'a> { obj_type: node::ObjectType, value_id: NodeId, ) -> Value { - if matches!(obj_type, node::ObjectType::Pointer(_)) { + if matches!(obj_type, ObjectType::Pointer(_)) { if let Ok(rhs_mut) = self.context.get_mut_variable(value_id) { rhs_mut.def = definition_id; rhs_mut.name = variable_name; @@ -440,7 +439,8 @@ impl<'a> IRGenerator<'a> { let id = self.context.add_variable(new_var, None); //Assign rhs to lhs - let result = self.context.new_instruction(id, value_id, node::Operation::Ass, obj_type); + let result = self.context.new_binary_instruction(BinaryOp::Assign, id, value_id, obj_type); + //This new variable should not be available in outer scopes. let cb = self.context.get_current_block_mut(); cb.update_variable(id, result); //update the value array. n.b. we should not update the name as it is the first assignment (let) @@ -485,14 +485,9 @@ impl<'a> IRGenerator<'a> { //ssa: we create a new variable a1 linked to a let new_var_id = self.context.add_variable(new_var, Some(ls_root)); - let rhs = &self.context[rhs_id]; - let r_type = rhs.get_type(); - let result = self.context.new_instruction( - new_var_id, - rhs_id, - node::Operation::Ass, - r_type, - ); + let result_type = self.context[rhs_id].get_type(); + let operation = Operation::binary(BinaryOp::Assign, new_var_id, rhs_id); + let result = self.context.new_instruction(operation, result_type); self.update_variable_id(ls_root, new_var_id, result); //update the name and the value map Value::Single(new_var_id) } @@ -548,20 +543,25 @@ impl<'a> IRGenerator<'a> { let arr_type = self.def_interner().id_type(expr_id); let element_type = arr_type.into(); //WARNING array type! - let array_index = self.context.mem.create_new_array(arr_lit.length as u32, element_type, &String::new()); + let array_id = self.context.mem.create_new_array(arr_lit.length as u32, element_type, &String::new()); //We parse the array definition let elements = self.expression_list_to_objects(env, &arr_lit.contents); - let array = &mut self.context.mem.arrays[array_index as usize]; + let array = &mut self.context.mem[array_id]; let array_adr = array.adr; for (pos, object) in elements.into_iter().enumerate() { //array.witness.push(node::get_witness_from_object(&object)); - let lhs_adr = self.context.get_or_create_const(FieldElement::from((array_adr + pos as u32) as u128), node::ObjectType::Unsigned(32)); - self.context.new_instruction(object, lhs_adr, node::Operation::Store(array_index), element_type); + let lhs_adr = self.context.get_or_create_const(FieldElement::from((array_adr + pos as u32) as u128), ObjectType::Unsigned(32)); + let store = Operation::Store { + array_id, + index: lhs_adr, + value: object, + }; + self.context.new_instruction(store, element_type); } //Finally, we create a variable pointing to this MemArray let new_var = node::Variable { id: NodeId::dummy(), - obj_type : node::ObjectType::Pointer(array_index), + obj_type : ObjectType::Pointer(array_id), name: String::new(), root: None, def: None, @@ -580,14 +580,13 @@ impl<'a> IRGenerator<'a> { // for e.g. struct == struct in the future let lhs = self.expression_to_object(env, &infx.lhs)?.unwrap_id(); let rhs = self.expression_to_object(env, &infx.rhs)?.unwrap_id(); - self.evaluate_infix_expression(lhs, rhs, infx.operator) - .map(Value::Single) + Ok(Value::Single(self.evaluate_infix_expression(lhs, rhs, infx.operator))) }, HirExpression::Cast(cast_expr) => { let lhs = self.expression_to_object(env, &cast_expr.lhs)?.unwrap_id(); let rtype = cast_expr.r#type.into(); - Ok(Value::Single(self.context.new_instruction(lhs, lhs, Operation::Cast, rtype))) + Ok(Value::Single(self.context.new_instruction(Operation::Cast(lhs), rtype))) //We should generate a cast instruction and handle properly type conversion: // unsigned integer to field ; ok, just checks if bit size over FieldElement::max_num_bits() @@ -613,25 +612,20 @@ impl<'a> IRGenerator<'a> { let arr_type = self.def_interner().id_type(arr_def); let o_type = arr_type.into(); - let mut array_index = self.context.mem.arrays.len() as u32; - let array = if let Some(moi) = self.context.mem.find_array(&Some(arr_def)) { - array_index= self.context.mem.get_array_index(moi).unwrap(); - moi - } - else if let Some(Value::Single(pointer)) = self.find_variable(arr_def) { + + let array = if let Some(array) = self.context.mem.find_array(arr_def) { + array + } else if let Some(Value::Single(pointer)) = self.find_variable(arr_def) { match self.context.get_object_type(*pointer) { - node::ObjectType::Pointer(a_id) => { - array_index = a_id; - &self.context.mem.arrays[a_id as usize] - } + ObjectType::Pointer(array_id) => &self.context.mem[array_id], _ => unreachable!(), } - } - else { + } else { let arr = env.get_array(&arr_name).map_err(|kind|kind.add_span(ident_span)).unwrap(); self.context.mem.create_array_from_object(&arr, arr_def, o_type, &arr_name) }; - //let array = self.mem.get_or_create_array(&arr, arr_def.unwrap(), o_type, arr_name); + + let array_id = array.id; let address = array.adr; // Evaluate the index expression @@ -639,8 +633,10 @@ impl<'a> IRGenerator<'a> { let index_type = self.context.get_object_type(index_as_obj); let base_adr = self.context.get_or_create_const(FieldElement::from(address as i128), index_type); - let adr_id = self.context.new_instruction(base_adr, index_as_obj, node::Operation::Add, index_type); - Ok(Value::Single(self.context.new_instruction(adr_id, adr_id, node::Operation::Load(array_index), o_type))) + let adr_id = self.context.new_instruction(Operation::binary(BinaryOp::Add, base_adr, index_as_obj), index_type); + + let load = Operation::Load { array_id, index: adr_id }; + Ok(Value::Single(self.context.new_instruction(load, o_type))) }, HirExpression::Call(call_expr) => { let func_meta = self.def_interner().function_meta(&call_expr.func_id); @@ -822,10 +818,11 @@ impl<'a> IRGenerator<'a> { let iter_type = int_type.into(); let iter_id = self.create_new_variable(iter_name, iter_def, iter_type, None); let iter_var = self.context.get_mut_variable(iter_id).unwrap(); - iter_var.obj_type = iter_type; - let iter_ass = - self.context.new_instruction(iter_id, start_idx, node::Operation::Ass, iter_type); + + let assign = Operation::binary(BinaryOp::Assign, iter_id, start_idx); + let iter_ass = self.context.new_instruction(assign, iter_type); + //We map the iterator to start_idx so that when we seal the join block, we will get the corrdect value. self.update_variable_id(iter_id, iter_ass, start_idx); @@ -834,21 +831,23 @@ impl<'a> IRGenerator<'a> { block::new_unsealed_block(&mut self.context, block::BlockType::ForJoin, true); let exit_id = block::new_sealed_block(&mut self.context, block::BlockType::Normal); self.context.current_block = join_idx; + //should parse a for_expr.condition statement that should evaluate to bool, but //we only supports i=start;i!=end for now //we generate the phi for the iterator because the iterator is manually created - let phi = self.generate_empty_phi(join_idx, iter_id); + let phi = self.context.generate_empty_phi(join_idx, iter_id); self.update_variable_id(iter_id, iter_id, phi); //is it still needed? - let cond = - self.context.new_instruction(phi, end_idx, Operation::Ne, node::ObjectType::Boolean); - let to_fix = self.context.new_instruction( - cond, - NodeId::dummy(), - node::Operation::Jeq, - node::ObjectType::NotAnObject, - ); + + let notequal = Operation::binary(BinaryOp::Ne, phi, end_idx); + let cond = self.context.new_instruction(notequal, ObjectType::Boolean); + + let to_fix = self.context.new_instruction(Operation::Nop, ObjectType::NotAnObject); + //Body let body_id = block::new_sealed_block(&mut self.context, block::BlockType::Normal); + self.context.try_get_mut_instruction(to_fix).unwrap().operation = + Operation::Jeq(cond, body_id); + let block = match self.def_interner().expression(&for_expr.block) { HirExpression::Block(block_expr) => block_expr, _ => panic!("ice: expected a block expression"), @@ -862,7 +861,10 @@ impl<'a> IRGenerator<'a> { //increment iter let one = self.context.get_or_create_const(FieldElement::one(), iter_type); - let incr = self.context.new_instruction(phi, one, node::Operation::Add, iter_type); + + let incr_op = Operation::binary(BinaryOp::Add, phi, one); + let incr = self.context.new_instruction(incr_op, iter_type); + let cur_block_id = self.context.current_block; //It should be the body block, except if the body has CFG statements let cur_block = &mut self.context[cur_block_id]; cur_block.update_variable(iter_id, incr); @@ -870,38 +872,17 @@ impl<'a> IRGenerator<'a> { cur_block.left = Some(join_idx); let join_mut = &mut self.context[join_idx]; join_mut.predecessor.push(cur_block_id); + //jump back to join - self.context.new_instruction( - NodeId::dummy(), - self.context[join_idx].get_first_instruction(), - node::Operation::Jmp, - node::ObjectType::NotAnObject, - ); + self.context.new_instruction(Operation::Jmp(join_idx), ObjectType::NotAnObject); + //seal join ssa_form::seal_block(&mut self.context, join_idx); //exit block self.context.current_block = exit_id; let exit_first = self.context.get_current_block().get_first_instruction(); block::link_with_target(&mut self.context, join_idx, Some(exit_id), Some(body_id)); - let first_instruction = self.context[body_id].get_first_instruction(); - self.context.try_get_mut_instruction(to_fix).unwrap().rhs = first_instruction; - Ok(Value::Single(exit_first)) //TODO what should we return??? - } - pub fn generate_empty_phi(&mut self, target_block: BlockId, root: NodeId) -> NodeId { - //Ensure there is not already a phi for the variable (n.b. probably not usefull) - for i in &self.context[target_block].instructions { - if let Some(ins) = self.context.try_get_instruction(*i) { - if ins.operator == node::Operation::Phi && ins.rhs == root { - return *i; - } - } - } - - let v_type = self.context.get_object_type(root); - let new_phi = Instruction::new(Operation::Phi, root, root, v_type, Some(target_block)); - let phi_id = self.context.add_instruction(new_phi); - self.context[target_block].instructions.insert(1, phi_id); - phi_id + Ok(Value::Single(exit_first)) //TODO what should we return??? } } diff --git a/crates/noirc_evaluator/src/ssa/context.rs b/crates/noirc_evaluator/src/ssa/context.rs index 3907a5f7245..86aecd3630e 100644 --- a/crates/noirc_evaluator/src/ssa/context.rs +++ b/crates/noirc_evaluator/src/ssa/context.rs @@ -1,18 +1,19 @@ use super::block::{BasicBlock, BlockId}; use super::function::SSAFunction; use super::mem::Memory; -use super::node::{Instruction, NodeId, NodeObj, ObjectType, Operation}; +use super::node::{BinaryOp, Instruction, NodeId, NodeObj, ObjectType, Operation}; use super::{block, flatten, integer, node, optim}; use std::collections::{HashMap, HashSet}; use super::super::errors::RuntimeError; use crate::ssa::acir_gen::Acir; use crate::ssa::function; -use crate::ssa::node::Node; +use crate::ssa::node::{Mark, Node}; use crate::Evaluator; use acvm::FieldElement; use noirc_frontend::hir::Context; use noirc_frontend::node_interner::FuncId; +use noirc_frontend::util::vecmap; use num_bigint::BigUint; use num_traits::Zero; @@ -69,6 +70,94 @@ impl<'a> SsaContext<'a> { } } + fn binary_to_string(&self, binary: &node::Binary) -> String { + let lhs = self.node_to_string(binary.lhs); + let rhs = self.node_to_string(binary.rhs); + + let op = match &binary.operator { + BinaryOp::Add => "add", + BinaryOp::SafeAdd => "safe_add", + BinaryOp::Sub { .. } => "sub", + BinaryOp::SafeSub { .. } => "safe_sub", + BinaryOp::Mul => "mul", + BinaryOp::SafeMul => "safe_mul", + BinaryOp::Udiv => "udiv", + BinaryOp::Sdiv => "sdiv", + BinaryOp::Urem => "urem", + BinaryOp::Srem => "srem", + BinaryOp::Div => "div", + BinaryOp::Eq => "eq", + BinaryOp::Ne => "ne", + BinaryOp::Ult => "ult", + BinaryOp::Ule => "ule", + BinaryOp::Slt => "slt", + BinaryOp::Sle => "sle", + BinaryOp::Lt => "lt", + BinaryOp::Lte => "lte", + BinaryOp::And => "and", + BinaryOp::Or => "or", + BinaryOp::Xor => "xor", + BinaryOp::Assign => "assign", + BinaryOp::Constrain(node::ConstrainOp::Eq) => "constrain_eq", + BinaryOp::Constrain(node::ConstrainOp::Neq) => "constrain_neq", + BinaryOp::Shl => "shl", + BinaryOp::Shr => "shr", + }; + + format!("{} {}, {}", op, lhs, rhs) + } + + fn operation_to_string(&self, op: &Operation) -> String { + let join = |args: &[NodeId]| vecmap(args, |arg| self.node_to_string(*arg)).join(", "); + + match op { + Operation::Binary(binary) => self.binary_to_string(binary), + Operation::Cast(value) => format!("cast {}", self.node_to_string(*value)), + Operation::Truncate { value, bit_size, max_bit_size } => { + format!( + "truncate {}, bitsize = {}, max bitsize = {}", + self.node_to_string(*value), + bit_size, + max_bit_size + ) + } + Operation::Not(v) => format!("not {}", self.node_to_string(*v)), + Operation::Jne(v, b) => format!("jne {}, {:?}", self.node_to_string(*v), b), + Operation::Jeq(v, b) => format!("jeq {}, {:?}", self.node_to_string(*v), b), + Operation::Jmp(b) => format!("jmp {:?}", b), + Operation::Phi { root, block_args } => { + let mut s = format!("phi {}", self.node_to_string(*root)); + for (value, block) in block_args { + s += &format!( + ", {} from block {}", + self.node_to_string(*value), + block.0.into_raw_parts().0 + ); + } + s + } + Operation::Load { array_id, index } => { + format!("load {:?}, index {}", array_id, self.node_to_string(*index)) + } + Operation::Store { array_id, index, value } => { + format!( + "store {:?}, index {}, value {}", + array_id, + self.node_to_string(*index), + self.node_to_string(*value) + ) + } + Operation::Intrinsic(opcode, args) => format!("intrinsic {}({})", opcode, join(args)), + Operation::Nop => "nop".into(), + Operation::Call(f, args) => format!("call {:?}({})", f, join(args)), + Operation::Return(values) => format!("return ({})", join(values)), + Operation::Result { call_instruction, index } => { + let call = self.node_to_string(*call_instruction); + format!("result {} of {}", index, call) + } + } + } + pub fn print_block(&self, b: &block::BasicBlock) { for id in &b.instructions { let ins = self.get_instruction(*id); @@ -77,21 +166,11 @@ impl<'a> SsaContext<'a> { } else { ins.res_name.clone() }; - if ins.is_deleted { - str_res += " -DELETED"; - } - let lhs_str = self.node_to_string(ins.lhs); - let rhs_str = self.node_to_string(ins.rhs); - let mut ins_str = format!("{} op:{:?} {}", lhs_str, ins.operator, rhs_str); - - if ins.operator == node::Operation::Phi { - ins_str += "("; - for (v, b) in &ins.phi_arguments { - ins_str += - &format!("{:?}:{:?}, ", v.0.into_raw_parts().0, b.0.into_raw_parts().0); - } - ins_str += ")"; + if let Mark::ReplaceWith(replacement) = ins.mark { + str_res = format!("{} -REPLACED with id {:?}", str_res, replacement.0); } + + let ins_str = self.operation_to_string(&ins.operation); println!("{}: {}", str_res, ins_str); } } @@ -219,6 +298,23 @@ impl<'a> SsaContext<'a> { } } + pub fn get_result_instruction( + &mut self, + target: BlockId, + call_instruction: NodeId, + index: u32, + ) -> Option<&mut Instruction> { + for id in &self.blocks[target.0].instructions { + if let Some(NodeObj::Instr(i)) = self.nodes.get(id.0) { + if i.operation == (Operation::Result { call_instruction, index }) { + let id = *id; + return self.try_get_mut_instruction(id); + } + } + } + None + } + pub fn get_root_value(&self, id: NodeId) -> NodeId { self.get_variable(id).map(|v| v.get_root()).unwrap_or(id) } @@ -235,45 +331,29 @@ impl<'a> SsaContext<'a> { id } - pub fn new_instruction( - &mut self, - lhs: NodeId, - rhs: NodeId, - opcode: node::Operation, - optype: node::ObjectType, - ) -> NodeId { - let operands = vec![lhs, rhs]; - self.new_instruction_with_multiple_operands(operands, opcode, optype) - } - - pub fn new_instruction_with_multiple_operands( - &mut self, - mut operands: Vec, - opcode: node::Operation, - optype: node::ObjectType, - ) -> NodeId { - while operands.len() < 2 { - operands.push(NodeId::dummy()); - } + pub fn new_instruction(&mut self, opcode: Operation, optype: ObjectType) -> NodeId { //Add a new instruction to the nodes arena - let mut i = node::Instruction::new( - opcode, - operands[0], - operands[1], - optype, - Some(self.current_block), - ); + let mut i = Instruction::new(opcode, optype, Some(self.current_block)); //Basic simplification optim::simplify(self, &mut i); - if operands.len() > 2 { - i.ins_arguments = operands; - } - if i.is_deleted { - return i.rhs; + + if let Mark::ReplaceWith(replacement) = i.mark { + return replacement; } self.push_instruction(i) } + pub fn new_binary_instruction( + &mut self, + operator: BinaryOp, + lhs: NodeId, + rhs: NodeId, + optype: ObjectType, + ) -> NodeId { + let operation = Operation::binary(operator, lhs, rhs); + self.new_instruction(operation, optype) + } + pub fn find_const_with_type( &self, value: &BigUint, @@ -307,30 +387,27 @@ impl<'a> SsaContext<'a> { } //Return the type of the operation result, based on the left hand type - pub fn get_result_type(&self, op: Operation, lhs_type: node::ObjectType) -> node::ObjectType { + pub fn get_result_type(&self, op: &Operation, lhs_type: node::ObjectType) -> node::ObjectType { + use {BinaryOp::*, Operation::*}; match op { - Operation::Eq - | Operation::Ne - | Operation::Ugt - | Operation::Uge - | Operation::Ult - | Operation::Ule - | Operation::Sgt - | Operation::Sge - | Operation::Slt - | Operation::Sle - | Operation::Lt - | Operation::Gt - | Operation::Lte - | Operation::Gte => ObjectType::Boolean, - Operation::Jne - | Operation::Jeq - | Operation::Jmp + Binary(node::Binary { operator: Eq, .. }) + | Binary(node::Binary { operator: Ne, .. }) + | Binary(node::Binary { operator: Ult, .. }) + | Binary(node::Binary { operator: Ule, .. }) + | Binary(node::Binary { operator: Slt, .. }) + | Binary(node::Binary { operator: Sle, .. }) + | Binary(node::Binary { operator: Lt, .. }) + | Binary(node::Binary { operator: Lte, .. }) => ObjectType::Boolean, + Operation::Jne(_, _) + | Operation::Jeq(_, _) + | Operation::Jmp(_) | Operation::Nop - | Operation::Constrain(_) - | Operation::Store(_) => ObjectType::NotAnObject, - Operation::Load(adr) => self.mem.arrays[adr as usize].element_type, - Operation::Cast | Operation::Trunc => unreachable!("cannot determine result type"), + | Binary(node::Binary { operator: Constrain(_), .. }) + | Operation::Store { .. } => ObjectType::NotAnObject, + Operation::Load { array_id, .. } => self.mem[*array_id].element_type, + Operation::Cast(_) | Operation::Truncate { .. } => { + unreachable!("cannot determine result type") + } _ => lhs_type, } } @@ -353,44 +430,44 @@ impl<'a> SsaContext<'a> { self.blocks.iter().map(|(_id, block)| block) } - pub fn pause(&self, interactive: bool, before: &str, after: &str) { - if_debug::if_debug!(if interactive { + pub fn log(&self, show_log: bool, before: &str, after: &str) { + if show_log { self.print(before); - let mut number = String::new(); - println!("Press enter to continue"); - std::io::stdin().read_line(&mut number).unwrap(); println!("{}", after); - }); + } } //Optimise, flatten and truncate IR and then generates ACIR representation from it pub fn ir_to_acir( &mut self, evaluator: &mut Evaluator, - interactive: bool, + enable_logging: bool, ) -> Result<(), RuntimeError> { //SSA - self.pause(interactive, "SSA:", "inline functions"); + self.log(enable_logging, "SSA:", "\ninline functions"); flatten::inline_all_functions(self); + //Optimisation block::compute_dom(self); optim::cse(self, self.first_block); - self.pause(interactive, "CSE:", "unrolling:"); + self.log(enable_logging, "\nCSE:", "\nunrolling:"); //Unrolling flatten::unroll_tree(self, self.first_block); + //Inlining - self.pause(interactive, "", "inlining:"); + self.log(enable_logging, "", "\ninlining:"); flatten::inline_tree(self, self.first_block); optim::cse(self, self.first_block); + //Truncation integer::overflow_strategy(self); - self.pause(interactive, "overflow:", ""); + self.log(enable_logging, "\noverflow:", ""); //ACIR self.acir(evaluator); - if_debug::if_debug!(if interactive { - dbg!("DONE"); + if enable_logging { + println!("DONE"); dbg!(&evaluator.current_witness_index); - }); + } Ok(()) } @@ -408,18 +485,22 @@ impl<'a> SsaContext<'a> { Acir::print_circuit(&evaluator.gates); } - pub fn generate_empty_phi(&mut self, target_block: BlockId, root: NodeId) -> NodeId { + pub fn generate_empty_phi(&mut self, target_block: BlockId, phi_root: NodeId) -> NodeId { //Ensure there is not already a phi for the variable (n.b. probably not usefull) for i in &self[target_block].instructions { - if let Some(ins) = self.try_get_instruction(*i) { - if ins.operator == node::Operation::Phi && ins.rhs == root { + match self.try_get_instruction(*i) { + Some(Instruction { operation: Operation::Phi { root, .. }, .. }) + if *root == phi_root => + { return *i; } + _ => (), } } - let v_type = self.get_object_type(root); - let new_phi = Instruction::new(Operation::Phi, root, root, v_type, Some(target_block)); + let v_type = self.get_object_type(phi_root); + let operation = Operation::Phi { root: phi_root, block_args: vec![] }; + let new_phi = Instruction::new(operation, v_type, Some(target_block)); let phi_id = self.add_instruction(new_phi); self[target_block].instructions.insert(1, phi_id); phi_id diff --git a/crates/noirc_evaluator/src/ssa/flatten.rs b/crates/noirc_evaluator/src/ssa/flatten.rs index dd223bc946a..e8b3c9c1ca7 100644 --- a/crates/noirc_evaluator/src/ssa/flatten.rs +++ b/crates/noirc_evaluator/src/ssa/flatten.rs @@ -2,12 +2,24 @@ use super::{ block::{self, BlockId}, context::SsaContext, function, - node::{self, Node, NodeEval, NodeId, NodeObj, Operation}, + mem::ArrayId, + node::{ + self, BinaryOp, Instruction, Mark, Node, NodeEval, NodeId, NodeObj, ObjectType, Operation, + }, optim, }; use acvm::FieldElement; +use std::collections::{hash_map::Entry, HashMap}; + +//returns the NodeObj index of a NodeEval object +//if NodeEval is a constant, it may creates a new NodeObj corresponding to the constant value +fn to_index(ctx: &mut SsaContext, obj: NodeEval) -> NodeId { + match obj { + NodeEval::Const(c, t) => ctx.get_or_create_const(c, t), + NodeEval::VarOrInstruction(i) => i, + } +} use noirc_frontend::util::vecmap; -use std::collections::HashMap; // Number of allowed times for inlining function calls inside a code block. // If a function calls another function, the inlining of the first function will leave the second function call that needs to be inlined as well. @@ -25,34 +37,29 @@ pub fn unroll_tree(ctx: &mut SsaContext, mut block_id: BlockId) { } //Update the block instruction list using the eval_map -fn eval_block(block_id: BlockId, eval_map: &HashMap, igen: &mut SsaContext) { - for i in &igen[block_id].instructions.clone() { - //RIA - if let Some(ins) = igen.try_get_mut_instruction(*i) { - if let Some(value) = eval_map.get(&ins.rhs) { - ins.rhs = value.into_node_id().unwrap(); - } - if let Some(value) = eval_map.get(&ins.lhs) { - ins.lhs = value.into_node_id().unwrap(); - } - //TODO simplify(ctx, ins); +fn eval_block(block_id: BlockId, eval_map: &HashMap, ctx: &mut SsaContext) { + for i in &ctx[block_id].instructions.clone() { + if let Some(ins) = ctx.try_get_mut_instruction(*i) { + ins.operation = update_operator(&ins.operation, eval_map); + // TODO: simplify(ctx, ins); } } } +fn update_operator(operator: &Operation, eval_map: &HashMap) -> Operation { + operator.map_id(|id| eval_map.get(&id).and_then(|value| value.into_node_id()).unwrap_or(id)) +} + pub fn unroll_block( unrolled_instructions: &mut Vec, eval_map: &mut HashMap, block_to_unroll: BlockId, - igen: &mut SsaContext, + ctx: &mut SsaContext, ) -> Option { - if igen[block_to_unroll].is_join() { - unroll_join(unrolled_instructions, eval_map, block_to_unroll, igen) - } else if let Some(i) = unroll_std_block(unrolled_instructions, eval_map, block_to_unroll, igen) - { - igen.try_get_instruction(i).map(|ins| ins.parent_block) + if ctx[block_to_unroll].is_join() { + unroll_join(unrolled_instructions, eval_map, block_to_unroll, ctx) } else { - None + unroll_std_block(unrolled_instructions, eval_map, block_to_unroll, ctx) } } @@ -61,51 +68,48 @@ pub fn unroll_std_block( unrolled_instructions: &mut Vec, eval_map: &mut HashMap, block_to_unroll: BlockId, - igen: &mut SsaContext, -) -> Option //first instruction of the left block + ctx: &mut SsaContext, +) -> Option // The left block { - let block = &igen[block_to_unroll]; + let block = &ctx[block_to_unroll]; let b_instructions = block.instructions.clone(); - let mut next = None; - if let Some(left) = block.left { - if let Some(f) = igen[left].instructions.first() { - next = Some(*f); - } - } + let next = block.left; + for i_id in &b_instructions { - match &igen[*i_id] { + match &ctx[*i_id] { node::NodeObj::Instr(i) => { - let new_left = get_current_value(i.lhs, eval_map).into_node_id().unwrap(); - let new_right = get_current_value(i.rhs, eval_map).into_node_id().unwrap(); + let new_op = i + .operation + .map_id(|id| get_current_value(id, eval_map).into_node_id().unwrap()); let mut new_ins = node::Instruction::new( - i.operator, new_left, new_right, i.res_type, None, //TODO to fix later + new_op, i.res_type, None, //TODO to fix later ); - match i.operator { - Operation::Ass => { + match i.operation { + Operation::Binary(node::Binary { operator: BinaryOp::Assign, .. }) => { unreachable!("unsupported instruction type when unrolling: assign"); //To support assignments, we should create a new variable and updates the eval_map with it //however assignments should have already been removed by copy propagation. } - Operation::Jmp => { - return Some(i.rhs); + Operation::Jmp(block) => { + return Some(block); } Operation::Nop => (), _ => { - optim::simplify(igen, &mut new_ins); - let result_id; - let mut to_delete = false; - if new_ins.is_deleted { - result_id = new_ins.rhs; - if new_ins.rhs == new_ins.id { - to_delete = true; + optim::simplify(ctx, &mut new_ins); + + match new_ins.mark { + Mark::None => { + let id = ctx.add_instruction(new_ins); + unrolled_instructions.push(id); + eval_map.insert(*i_id, NodeEval::VarOrInstruction(id)); + } + Mark::Deleted => (), + Mark::ReplaceWith(replacement) => { + // TODO: Should we insert into unrolled_instructions as well? + // If optim::simplify replaces with a constant then we should not, + // otherwise it may make sense if it is not already inserted. + eval_map.insert(*i_id, NodeEval::VarOrInstruction(replacement)); } - } else { - result_id = igen.add_instruction(new_ins); - unrolled_instructions.push(result_id); - } - //ignore self-deleted instructions - if !to_delete { - eval_map.insert(*i_id, NodeEval::VarOrInstruction(result_id)); } } } @@ -125,10 +129,10 @@ pub fn unroll_join( unrolled_instructions: &mut Vec, eval_map: &mut HashMap, block_to_unroll: BlockId, - igen: &mut SsaContext, + ctx: &mut SsaContext, ) -> Option { //Returns the exit block of the loop - let join = &igen[block_to_unroll]; + let join = &ctx[block_to_unroll]; let join_instructions = join.instructions.clone(); let join_left = join.left; //XXX.clone(); let prev = *join.predecessor.first().unwrap(); //todo predecessor.first or .last? @@ -141,12 +145,12 @@ pub fn unroll_join( } while { //evaluate the join block: - evaluate_phi(&join_instructions, from, eval_map, igen); - evaluate_conditional_jump(*join_instructions.last().unwrap(), eval_map, igen) + evaluate_phi(&join_instructions, from, eval_map, ctx); + evaluate_conditional_jump(*join_instructions.last().unwrap(), eval_map, ctx) } { from = block_to_unroll; let mut b_id = body_id; - while let Some(next) = unroll_block(unrolled_instructions, eval_map, b_id, igen) { + while let Some(next) = unroll_block(unrolled_instructions, eval_map, b_id, ctx) { //process next block: from = b_id; b_id = next; @@ -164,39 +168,36 @@ pub fn outer_unroll( unroll_ins: &mut Vec, //unrolled instructions eval_map: &mut HashMap, block_id: BlockId, //block to unroll - igen: &mut SsaContext, + ctx: &mut SsaContext, ) -> Option //next block { assert!(unroll_ins.is_empty()); - let block = &igen[block_id]; + let block = &ctx[block_id]; let b_right = block.right; let b_left = block.left; let block_instructions = block.instructions.clone(); if block.is_join() { //1. unroll the block into the unroll_ins - unroll_join(unroll_ins, eval_map, block_id, igen); + unroll_join(unroll_ins, eval_map, block_id, ctx); //2. map the Phis variables to their unrolled values: for ins in &block_instructions { - if let Some(ins_obj) = igen.try_get_instruction(*ins) { - if ins_obj.operator == node::Operation::Phi { - if eval_map.contains_key(&ins_obj.rhs) { - eval_map.insert(ins_obj.lhs, eval_map[&ins_obj.rhs]); - //todo test with constants - } else if eval_map.contains_key(&ins_obj.id) { - // unroll_map.insert(ins_obj.lhs, eval_map[&ins_obj.idx].to_index().unwrap()); - eval_map.insert(ins_obj.lhs, eval_map[&ins_obj.id]); + if let Some(ins_obj) = ctx.try_get_instruction(*ins) { + if let Operation::Phi { root, .. } = &ins_obj.operation { + if let Some(node_eval) = eval_map.get(&ins_obj.id) { + let node_eval = *node_eval; + eval_map.entry(*root).or_insert(node_eval); //todo test with constants } - } else if ins_obj.operator != node::Operation::Nop { + } else if ins_obj.operation != node::Operation::Nop { break; //no more phis } } } //3. Merge the unrolled blocks into the join for ins in unroll_ins.iter() { - igen[*ins].set_id(*ins); + ctx[*ins].set_id(*ins); } - let join_mut = &mut igen[block_id]; + let join_mut = &mut ctx[block_id]; join_mut.instructions = unroll_ins.clone(); join_mut.right = None; join_mut.kind = block::BlockType::Normal; @@ -210,9 +211,9 @@ pub fn outer_unroll( } //we get the subgraph, however we could retrieve the list of processed blocks directly in unroll_join (cf. processed) if let Some(body_id) = b_right { - let sub_graph = block::bfs(body_id, Some(block_id), igen); + let sub_graph = block::bfs(body_id, Some(block_id), ctx); for b in sub_graph { - igen.remove_block(b); + ctx.remove_block(b); } } @@ -220,7 +221,7 @@ pub fn outer_unroll( unroll_ins.clear(); } else { //We update block instructions from the eval_map - eval_block(block_id, eval_map, igen); + eval_block(block_id, eval_map, ctx); } b_left //returns the next block to process } @@ -230,28 +231,28 @@ fn evaluate_phi( instructions: &[NodeId], from: BlockId, to: &mut HashMap, - igen: &mut SsaContext, + ctx: &mut SsaContext, ) { for i in instructions { let mut to_process = Vec::new(); - if let Some(ins) = igen.try_get_instruction(*i) { - if ins.operator == node::Operation::Phi { - for phi in &ins.phi_arguments { - if phi.1 == from { + if let Some(ins) = ctx.try_get_instruction(*i) { + if let Operation::Phi { block_args, .. } = &ins.operation { + for (arg, block) in block_args { + if *block == from { //we evaluate the phi instruction value - to_process.push(( - ins.id, - evaluate_one(NodeEval::VarOrInstruction(phi.0), to, igen), - )); + let arg = *arg; + let id = ins.id; + to_process + .push((id, evaluate_one(NodeEval::VarOrInstruction(arg), to, ctx))); } } - } else if ins.operator != node::Operation::Nop { + } else if ins.operation != node::Operation::Nop { break; //phi instructions are placed at the beginning (and after the first dummy instruction) } } //Update the evaluation map. for obj in to_process { - to.insert(obj.0, NodeEval::VarOrInstruction(optim::to_index(igen, obj.1))); + to.insert(obj.0, NodeEval::VarOrInstruction(to_index(ctx, obj.1))); } } } @@ -260,22 +261,27 @@ fn evaluate_phi( fn evaluate_conditional_jump( jump: NodeId, value_array: &mut HashMap, - ctx: &SsaContext, + ctx: &mut SsaContext, ) -> bool { let jump_ins = ctx.try_get_instruction(jump).unwrap(); - let lhs = get_current_value(jump_ins.lhs, value_array); - let cond = evaluate_object(lhs, value_array, ctx); - if let Some(cond_const) = cond.into_const_value() { - let result = !cond_const.is_zero(); - match jump_ins.operator { - node::Operation::Jeq => return result, - node::Operation::Jne => return !result, - node::Operation::Jmp => return true, - _ => panic!("loop without conditional statement!"), //TODO shouldn't we return false instead? - } - } - unreachable!("Condition should be constant"); + let (cond_id, should_jump): (_, fn(FieldElement) -> bool) = match jump_ins.operation { + Operation::Jeq(cond_id, _) => (cond_id, |field| !field.is_zero()), + Operation::Jne(cond_id, _) => (cond_id, |field| field.is_zero()), + Operation::Jmp(_) => return true, + _ => panic!("loop without conditional statement!"), //TODO shouldn't we return false instead? + }; + + let cond = get_current_value(cond_id, value_array); + let cond = match evaluate_object(cond, value_array, ctx).into_const_value() { + Some(c) => c, + None => unreachable!( + "Conditional jump argument is non-const: {:?}", + evaluate_object(cond, value_array, ctx) + ), + }; + + should_jump(cond) } //Retrieve the NodeEval value of the index in the evaluation map @@ -298,28 +304,26 @@ fn get_current_value_for_node_eval( fn evaluate_one( obj: NodeEval, value_array: &HashMap, - igen: &SsaContext, + ctx: &SsaContext, ) -> NodeEval { match get_current_value_for_node_eval(obj, value_array) { NodeEval::Const(_, _) => obj, NodeEval::VarOrInstruction(obj_id) => { - if igen.try_get_node(obj_id).is_none() { + if ctx.try_get_node(obj_id).is_none() { return obj; } - match &igen[obj_id] { + match &ctx[obj_id] { NodeObj::Instr(i) => { - if i.operator == node::Operation::Phi { + if let Operation::Phi { .. } = i.operation { //n.b phi are handled before, else we should know which block we come from dbg!(i.id); return NodeEval::VarOrInstruction(i.id); } - let lhs = get_current_value(i.lhs, value_array); - let lhr = get_current_value(i.rhs, value_array); - let result = i.evaluate(&lhs, &lhr); + let result = i.evaluate_with(ctx, |_, id| get_current_value(id, value_array)); if let NodeEval::VarOrInstruction(idx) = result { - if igen.try_get_node(idx).is_none() { + if ctx.try_get_node(idx).is_none() { return NodeEval::VarOrInstruction(obj_id); } } @@ -339,30 +343,30 @@ fn evaluate_one( fn evaluate_object( obj: NodeEval, value_array: &HashMap, - igen: &SsaContext, + ctx: &SsaContext, ) -> NodeEval { match get_current_value_for_node_eval(obj, value_array) { NodeEval::Const(_, _) => obj, NodeEval::VarOrInstruction(obj_id) => { - if igen.try_get_node(obj_id).is_none() { + if ctx.try_get_node(obj_id).is_none() { dbg!(obj_id); return obj; } - match &igen[obj_id] { + match &ctx[obj_id] { NodeObj::Instr(i) => { - if i.operator == Operation::Phi { + if let Operation::Phi { .. } = i.operation { dbg!(i.id); return NodeEval::VarOrInstruction(i.id); } + //n.b phi are handled before, else we should know which block we come from - let lhs = - evaluate_object(get_current_value(i.lhs, value_array), value_array, igen); - let lhr = - evaluate_object(get_current_value(i.rhs, value_array), value_array, igen); - let result = i.evaluate(&lhs, &lhr); + let result = i.evaluate_with(ctx, |ctx, id| { + evaluate_object(get_current_value(id, value_array), value_array, ctx) + }); + if let NodeEval::VarOrInstruction(idx) = result { - if igen.try_get_node(idx).is_none() { + if ctx.try_get_node(idx).is_none() { return NodeEval::VarOrInstruction(obj_id); } } @@ -420,24 +424,22 @@ pub fn inline_all_functions(ctx: &mut SsaContext) { } } -//inline all function calls of the block //Return false if some inlined function performs a function call fn inline_block(ctx: &mut SsaContext, block_id: BlockId) -> bool { - let mut call_ins = Vec::::new(); + let mut call_ins = vec![]; for i in &ctx[block_id].instructions { if let Some(ins) = ctx.try_get_instruction(*i) { - if !ins.is_deleted && matches!(ins.operator, node::Operation::Call(_)) { - call_ins.push(*i); + if !ins.is_deleted() { + if let Operation::Call(f, args) = &ins.operation { + call_ins.push((ins.id, *f, args.clone(), ins.parent_block)); + } } } } let mut result = true; - for ins_id in call_ins { - let ins = ctx.try_get_instruction(ins_id).unwrap().clone(); - if let node::Instruction { operator: node::Operation::Call(f), .. } = ins { - if !inline(f, &ins.ins_arguments, ctx, ins.parent_block, ins.id) { - result = false; - } + for (ins_id, f, args, parent_block) in call_ins { + if !inline(f, &args, ctx, parent_block, ins_id) { + result = false; } } optim::cse(ctx, block_id); //handles the deleted call instructions @@ -456,8 +458,8 @@ pub fn inline( let ssa_func = ctx.get_ssafunc(func_id).unwrap(); //map nodes from the function cfg to the caller cfg - let mut inline_map: HashMap = HashMap::new(); - let mut array_map: HashMap = HashMap::new(); + let mut inline_map = HashMap::::new(); + let mut array_map = HashMap::::new(); //1. map function parameters for (arg_caller, arg_function) in args.iter().zip(&ssa_func.arguments) { inline_map.insert(*arg_function, *arg_caller); @@ -488,7 +490,7 @@ pub fn inline_in_block( block_id: BlockId, target_block_id: BlockId, inline_map: &mut HashMap, - array_map: &mut HashMap, + array_map: &mut HashMap, call_id: NodeId, nested_call: &mut bool, ctx: &mut SsaContext, @@ -499,176 +501,127 @@ pub fn inline_in_block( let block_func_instructions = &block_func.instructions.clone(); *nested_call = false; for &i_id in block_func_instructions { - let mut array_func = None; - let mut array_func_idx = u32::MAX; if let Some(ins) = ctx.try_get_instruction(i_id) { - if ins.is_deleted { + if ins.is_deleted() { continue; } - let clone = ins.clone(); - if let node::ObjectType::Pointer(a) = ins.res_type { + let mut array = None; + let mut array_id = None; + let mut clone = ins.clone(); + + if let node::ObjectType::Pointer(id) = ins.res_type { //We need to map arrays to arrays via the array_map, we collect the data here to be mapped below. - array_func = Some(ctx.mem.arrays[a as usize].clone()); - array_func_idx = a; - } else if let Operation::Load(a) = ins.operator { - array_func = Some(ctx.mem.arrays[a as usize].clone()); - array_func_idx = a; - } else if let Operation::Store(a) = ins.operator { - array_func = Some(ctx.mem.arrays[a as usize].clone()); - array_func_idx = a; + array = Some(ctx.mem[id].clone()); + array_id = Some(id); + } else if let Operation::Load { array_id: id, .. } = ins.operation { + array = Some(ctx.mem[id].clone()); + array_id = Some(id); + } else if let Operation::Store { array_id: id, .. } = ins.operation { + array = Some(ctx.mem[id].clone()); + array_id = Some(id); } - let new_left = - function::SSAFunction::get_mapped_value(Some(&clone.lhs), ctx, inline_map); - let new_right = - function::SSAFunction::get_mapped_value(Some(&clone.rhs), ctx, inline_map); - let new_arg = function::SSAFunction::get_mapped_value( - clone.ins_arguments.first(), - ctx, - inline_map, - ); + clone.operation.map_id_mut(|id| { + function::SSAFunction::get_mapped_value(Some(&id), ctx, inline_map) + }); + //Arrays are mapped to array. We create the array if not mapped - if let Some(a) = array_func { - if let std::collections::hash_map::Entry::Vacant(e) = - array_map.entry(array_func_idx) - { - let i_pointer = ctx.mem.create_new_array(a.len, a.element_type, &a.name); + if let (Some(array), Some(array_id)) = (array, array_id) { + if let Entry::Vacant(e) = array_map.entry(array_id) { + let new_id = + ctx.mem.create_new_array(array.len, array.element_type, &array.name); //We populate the array (if possible) using the inline map - for i in &a.values { + for i in &array.values { if let Some(f) = i.to_const() { - ctx.mem.arrays[i_pointer as usize] - .values - .push(super::acir_gen::InternalVar::from(f)); + ctx.mem[new_id].values.push(super::acir_gen::InternalVar::from(f)); } //todo: else use inline map. } - e.insert(i_pointer); + e.insert(new_id); }; } - match clone.operator { + match &clone.operation { Operation::Nop => (), //Return instruction: - Operation::Ret => { + Operation::Return(values) => { //we need to find the corresponding result instruction in the target block (using ins.rhs) and replace it by ins.lhs - if let Some(ret_id) = ctx[target_block_id].get_result_instruction(call_id, ctx) - { - //we support only one result for now, should use 'ins.lhs.get_value()' - if let node::NodeObj::Instr(i) = &mut ctx[ret_id] { - i.is_deleted = true; - i.rhs = new_left; //Then we need to ensure there is a CSE. - } - } else { - //we use the call instruction instead - //we could use the ins_arguments to get the results here, and since we have the input arguments (in the ssafunction) we know how many there are. - //for now the call instruction is replaced by the (one) result - let call_ins = ctx.get_mut_instruction(call_id); - call_ins.is_deleted = true; - call_ins.rhs = new_arg; - if array_map.contains_key(&array_func_idx) { - let i_pointer = array_map[&array_func_idx]; - call_ins.res_type = node::ObjectType::Pointer(i_pointer); - } + for (i, value) in values.iter().enumerate() { + ctx.get_result_instruction(target_block_id, call_id, i as u32) + .unwrap() + .mark = Mark::ReplaceWith(*value); } } - Operation::Call(_) => { + Operation::Call(..) => { *nested_call = true; - - let mut new_ins = node::Instruction::new( - clone.operator, - new_left, - new_right, - clone.res_type, - Some(target_block_id), - ); - new_ins.ins_arguments = Vec::new(); - for i in clone.ins_arguments { - new_ins.ins_arguments.push(function::SSAFunction::get_mapped_value( - Some(&i), - ctx, - inline_map, - )); - } - let result_id = ctx.add_instruction(new_ins); - new_instructions.push(result_id); - inline_map.insert(i_id, result_id); + let new_ins = new_cloned_instruction(clone, target_block_id); + push_instruction(ctx, new_ins, &mut new_instructions, inline_map); } - Operation::Load(a) => { + Operation::Load { array_id, index } => { //Compute the new address: //TODO use relative addressing, but that requires a few changes, mainly in acir_gen.rs and integer.rs - let b = array_map[&a]; + let b = array_map[array_id]; //n.b. this offset is always positive - let offset = ctx.mem.arrays[b as usize].adr - ctx.mem.arrays[a as usize].adr; - let index_type = ctx[new_left].get_type(); + let offset = ctx.mem[b].adr - ctx.mem[*array_id].adr; + let index_type = ctx[*index].get_type(); let offset_id = ctx.get_or_create_const(FieldElement::from(offset as i128), index_type); - let adr_id = - ctx.new_instruction(offset_id, new_left, node::Operation::Add, index_type); - let new_ins = node::Instruction::new( - node::Operation::Load(array_map[&a]), - adr_id, - adr_id, + + let add = node::Binary { operator: BinaryOp::Add, lhs: offset_id, rhs: *index }; + let adr_id = ctx.new_instruction(Operation::Binary(add), index_type); + let new_ins = Instruction::new( + Operation::Load { array_id: array_map[array_id], index: adr_id }, clone.res_type, Some(target_block_id), ); - let result_id = ctx.add_instruction(new_ins); - new_instructions.push(result_id); - inline_map.insert(i_id, result_id); + push_instruction(ctx, new_ins, &mut new_instructions, inline_map); } - Operation::Store(a) => { - let b = array_map[&a]; - let offset = ctx.mem.arrays[a as usize].adr - ctx.mem.arrays[b as usize].adr; - let index_type = ctx[new_left].get_type(); + Operation::Store { array_id, index, value } => { + let b = array_map[array_id]; + let offset = ctx.mem[*array_id].adr - ctx.mem[b].adr; + let index_type = ctx[*index].get_type(); let offset_id = ctx.get_or_create_const(FieldElement::from(offset as i128), index_type); - let adr_id = - ctx.new_instruction(offset_id, new_left, node::Operation::Add, index_type); - let new_ins = node::Instruction::new( - node::Operation::Store(array_map[&a]), - new_left, - adr_id, + + let add = node::Binary { operator: BinaryOp::Add, lhs: offset_id, rhs: *index }; + let adr_id = ctx.new_instruction(Operation::Binary(add), index_type); + let new_ins = Instruction::new( + Operation::Store { + array_id: array_map[array_id], + index: adr_id, + value: *value, + }, clone.res_type, Some(target_block_id), ); - let result_id = ctx.add_instruction(new_ins); - new_instructions.push(result_id); - inline_map.insert(i_id, result_id); + push_instruction(ctx, new_ins, &mut new_instructions, inline_map); } _ => { - let mut new_ins = node::Instruction::new( - clone.operator, - new_left, - new_right, - clone.res_type, - Some(target_block_id), - ); - if array_map.contains_key(&array_func_idx) { - let i_pointer = array_map[&array_func_idx]; - new_ins.res_type = node::ObjectType::Pointer(i_pointer); + let mut new_ins = new_cloned_instruction(clone, target_block_id); + + if let Some(id) = array_id { + if let Some(new_id) = array_map.get(&id) { + new_ins.res_type = node::ObjectType::Pointer(*new_id); + } } + optim::simplify(ctx, &mut new_ins); - let result_id; - let mut to_delete = false; - if new_ins.is_deleted { - result_id = new_ins.rhs; - if let std::collections::hash_map::Entry::Occupied(mut e) = - array_map.entry(array_func_idx) - { - if let node::ObjectType::Pointer(a) = ctx[result_id].get_type() { - //we now map the array to rhs array - e.insert(a); + + if let Mark::ReplaceWith(replacement) = new_ins.mark { + if let Some(id) = array_id { + if let Entry::Occupied(mut entry) = array_map.entry(id) { + if let ObjectType::Pointer(new_id) = ctx[replacement].get_type() { + //we now map the array to rhs array + entry.insert(new_id); + } } } - if new_ins.rhs == new_ins.id { - to_delete = true; + + if replacement != new_ins.id { + inline_map.insert(i_id, replacement); } } else { - result_id = ctx.add_instruction(new_ins); - new_instructions.push(result_id); - } - //ignore self-deleted instructions - if !to_delete { - inline_map.insert(i_id, result_id); + push_instruction(ctx, new_ins, &mut new_instructions, inline_map); } } } @@ -684,3 +637,22 @@ pub fn inline_in_block( next_block } + +fn new_cloned_instruction(original: Instruction, block: BlockId) -> Instruction { + let mut clone = Instruction::new(original.operation, original.res_type, Some(block)); + // Take the original's ID, it will be used to map it as a replacement in push_instruction later + clone.id = original.id; + clone +} + +fn push_instruction( + ctx: &mut SsaContext, + instruction: Instruction, + new_instructions: &mut Vec, + inline_map: &mut HashMap, +) { + let old_id = instruction.id; + let new_id = ctx.add_instruction(instruction); + new_instructions.push(new_id); + inline_map.insert(old_id, new_id); +} diff --git a/crates/noirc_evaluator/src/ssa/function.rs b/crates/noirc_evaluator/src/ssa/function.rs index c8be763265d..e66549043e9 100644 --- a/crates/noirc_evaluator/src/ssa/function.rs +++ b/crates/noirc_evaluator/src/ssa/function.rs @@ -56,16 +56,16 @@ impl SSAFunction { igen: &mut IRGenerator, env: &mut Environment, ) -> NodeId { - let call_id = igen.context.new_instruction( - NodeId::dummy(), - NodeId::dummy(), - node::Operation::Call(func), - node::ObjectType::NotAnObject, //TODO how to get the function return type? - ); - let ins_arguments = igen.expression_list_to_objects(env, arguments); - let call_ins = igen.context.get_mut_instruction(call_id); - call_ins.ins_arguments = ins_arguments; - call_id + let arguments = igen.expression_list_to_objects(env, arguments); + let call_instruction = igen + .context + .new_instruction(node::Operation::Call(func, arguments), node::ObjectType::NotAnObject); + + //TODO how to get the function return type? + igen.context.new_instruction( + node::Operation::Result { call_instruction, index: 0 }, + node::ObjectType::NotAnObject, + ) } pub fn get_mapped_value( @@ -85,7 +85,7 @@ impl SSAFunction { if let Some(c) = my_const { ctx.get_or_create_const(c.0, c.1) } else { - *inline_map.get(&node_id).unwrap() + inline_map[&node_id] } } else { NodeId::dummy() @@ -146,11 +146,7 @@ pub fn call_low_level( //when the function returns an array, we use ins.res_type(array) //else we map ins.id to the returned witness //Call instruction - igen.context.new_instruction_with_multiple_operands( - args, - node::Operation::Intrinsic(op), - result_type, - ) + igen.context.new_instruction(node::Operation::Intrinsic(op, args), result_type) } pub fn create_function( @@ -190,13 +186,8 @@ pub fn create_function( pub fn add_return_instruction(cfg: &mut SsaContext, last: Option) { let last_id = last.unwrap_or_else(NodeId::dummy); - let result = if last_id == NodeId::dummy() { Vec::new() } else { vec![last_id] }; + let result = if last_id == NodeId::dummy() { vec![] } else { vec![last_id] }; //Create return instruction based on the last statement of the function body - let result_id = cfg.new_instruction( - NodeId::dummy(), - NodeId::dummy(), - node::Operation::Ret, - node::ObjectType::NotAnObject, - ); - cfg.get_mut_instruction(result_id).ins_arguments = result; //n.b. should we keep the object type in the vector? + cfg.new_instruction(node::Operation::Return(result), node::ObjectType::NotAnObject); + //n.b. should we keep the object type in the vector? } diff --git a/crates/noirc_evaluator/src/ssa/integer.rs b/crates/noirc_evaluator/src/ssa/integer.rs index ecb29b574db..eba266926f2 100644 --- a/crates/noirc_evaluator/src/ssa/integer.rs +++ b/crates/noirc_evaluator/src/ssa/integer.rs @@ -2,10 +2,12 @@ use super::{ block::BlockId, //block, context::SsaContext, - node::{self, Instruction, Node, NodeId, NodeObj, Operation}, + mem::{ArrayId, Memory}, + node::{self, BinaryOp, Instruction, Mark, Node, NodeId, NodeObj, ObjectType, Operation}, optim, }; -use acvm::FieldElement; +use acvm::{acir::OPCODE, FieldElement}; +use noirc_frontend::util::vecmap; use num_bigint::BigUint; use num_traits::{One, Zero}; use std::convert::TryInto; @@ -19,50 +21,54 @@ pub fn short_integer_max_bit_size() -> u32 { } //Gets the maximum value of the instruction result -pub fn get_instruction_max( +fn get_instruction_max( ctx: &SsaContext, - ins: &node::Instruction, + ins: &Instruction, max_map: &mut HashMap, vmap: &HashMap, ) -> BigUint { - let r_max = get_obj_max_value(ctx, None, ins.rhs, max_map, vmap); - let l_max = get_obj_max_value(ctx, None, ins.lhs, max_map, vmap); - get_instruction_max_operand(ctx, ins, l_max, r_max, max_map, vmap) + ins.operation.for_each_id(|id| { + get_obj_max_value(ctx, id, max_map, vmap); + }); + get_instruction_max_operand(ctx, ins, max_map, vmap) } //Gets the maximum value of the instruction result using the provided operand maximum -pub fn get_instruction_max_operand( +fn get_instruction_max_operand( ctx: &SsaContext, - ins: &node::Instruction, - left_max: BigUint, - right_max: BigUint, + ins: &Instruction, max_map: &mut HashMap, vmap: &HashMap, ) -> BigUint { - match ins.operator { - node::Operation::Load(array) => get_load_max(ctx, ins.lhs, max_map, vmap, array), - node::Operation::Sub => { - //TODO uses interval analysis instead - if matches!(ins.res_type, node::ObjectType::Unsigned(_)) { - if let Some(lhs_const) = ctx.get_as_constant(ins.lhs) { - let lhs_big = BigUint::from_bytes_be(&lhs_const.to_bytes()); - if right_max <= lhs_big { - //todo unsigned - return lhs_big; + match &ins.operation { + Operation::Load { array_id, index } => get_load_max(ctx, *index, max_map, vmap, *array_id), + Operation::Binary(node::Binary { operator, lhs, rhs }) => { + match operator { + BinaryOp::Sub { .. } => { + //TODO uses interval analysis instead + if matches!(ins.res_type, ObjectType::Unsigned(_)) { + if let Some(lhs_const) = ctx.get_as_constant(*lhs) { + let lhs_big = BigUint::from_bytes_be(&lhs_const.to_bytes()); + if max_map[rhs] <= lhs_big { + //TODO unsigned + return lhs_big; + } + } } + get_max_value(ins, max_map) } + BinaryOp::Constrain(_) => { + //ContrainOp::Eq : + //TODO... we should update the max_map AFTER the truncate is processed (else it breaks it) + // let min = BigUint::min(left_max.clone(), right_max.clone()); + // max_map.insert(ins.lhs, min.clone()); + // max_map.insert(ins.rhs, min); + get_max_value(ins, max_map) + } + _ => get_max_value(ins, max_map), } - get_max_value(ins, left_max, right_max) - } - node::Operation::Constrain(_) => { - //ContrainOp::Eq : - //TODO... we should update the max_map AFTER the truncate is processed (else it breaks it) - // let min = BigUint::min(left_max.clone(), right_max.clone()); - // max_map.insert(ins.lhs, min.clone()); - // max_map.insert(ins.rhs, min); - get_max_value(ins, left_max, right_max) } - _ => get_max_value(ins, left_max, right_max), + _ => get_max_value(ins, max_map), } } @@ -70,9 +76,8 @@ pub fn get_instruction_max_operand( // or else we compute it. // we use the value array (get_current_value2) in order to handle truncate instructions // we need to do it because rust did not allow to modify the instruction in block_overflow.. -pub fn get_obj_max_value( +fn get_obj_max_value( ctx: &SsaContext, - obj: Option<&NodeObj>, id: NodeId, max_map: &mut HashMap, vmap: &HashMap, @@ -82,17 +87,13 @@ pub fn get_obj_max_value( return max_map[&id].clone(); } if id == NodeId::dummy() { + max_map.insert(id, BigUint::zero()); return BigUint::zero(); //a non-argument has no max } - let obj_ = obj.unwrap_or_else(|| &ctx[id]); + let obj = &ctx[id]; - let result = match obj_ { - NodeObj::Obj(v) => { - if v.size_in_bits() > 100 { - dbg!(&v); - } - (BigUint::one() << v.size_in_bits()) - BigUint::one() - } //TODO check for signed type + let result = match obj { + NodeObj::Obj(v) => (BigUint::one() << v.size_in_bits()) - BigUint::one(), //TODO check for signed type NodeObj::Instr(i) => get_instruction_max(ctx, i, max_map, vmap), NodeObj::Const(c) => c.value.clone(), //TODO panic for string constants }; @@ -101,7 +102,7 @@ pub fn get_obj_max_value( } //Creates a truncate instruction for obj_id -pub fn truncate( +fn truncate( ctx: &mut SsaContext, obj_id: NodeId, bit_size: u32, @@ -115,19 +116,21 @@ pub fn truncate( let v_max = &max_map[&obj_id]; if *v_max >= BigUint::one() << bit_size { - //TODO is this leaking some info???? - let rhs_bitsize = ctx.get_or_create_const( - FieldElement::from(bit_size as i128), - node::ObjectType::Unsigned(32), - ); + //TODO is max_bit_size leaking some info???? //Create a new truncate instruction '(idx): obj trunc bit_size' //set current value of obj to idx - let mut i = Instruction::new(Operation::Trunc, obj_id, rhs_bitsize, obj_type, None); + let max_bit_size = v_max.bits() as u32; + + let mut i = Instruction::new( + Operation::Truncate { value: obj_id, bit_size, max_bit_size }, + obj_type, + None, + ); + if i.res_name.ends_with("_t") { //TODO we should use %t so that we can check for this substring (% is not a valid char for a variable name) in the name and then write name%t[number+1] } i.res_name = obj_name + "_t"; - i.bit_size = v_max.bits() as u32; let i_id = ctx.add_instruction(i); max_map.insert(i_id, BigUint::from((1_u128 << bit_size) - 1)); Some(i_id) @@ -140,7 +143,7 @@ pub fn truncate( //Set the id and parent block of the truncate instruction //This is needed because the instruction is inserted into a block and not added in the current block like regular instructions //We also update the value array -pub fn fix_truncate( +fn fix_truncate( eval: &mut SsaContext, id: NodeId, prev_id: NodeId, @@ -160,23 +163,18 @@ fn add_to_truncate( bit_size: u32, to_truncate: &mut HashMap, max_map: &HashMap, -) -> BigUint { +) { let v_max = &max_map[&obj_id]; if *v_max >= BigUint::one() << bit_size { - if let Some(node::NodeObj::Const(_)) = &ctx.try_get_node(obj_id) { - return v_max.clone(); //a constant cannot be truncated, so we exit the function gracefully + if let Some(NodeObj::Const(_)) = &ctx.try_get_node(obj_id) { + return; //a constant cannot be truncated, so we exit the function gracefully } - let truncate_bits; - if to_truncate.contains_key(&obj_id) { - truncate_bits = u32::min(to_truncate[&obj_id], bit_size); - to_truncate.insert(obj_id, truncate_bits); - } else { - to_truncate.insert(obj_id, bit_size); - truncate_bits = bit_size; - } - return BigUint::from(truncate_bits - 1); + let truncate_bits = match to_truncate.get(&obj_id) { + Some(value) => u32::min(*value, bit_size), + None => bit_size, + }; + to_truncate.insert(obj_id, truncate_bits); } - v_max.clone() } //Truncate the 'to_truncate' list @@ -198,33 +196,6 @@ fn process_to_truncate( to_truncate.clear(); } -//Update right and left operands of the provided instruction -fn update_ins_parameters( - ctx: &mut SsaContext, - id: NodeId, - lhs: NodeId, - rhs: NodeId, - ins_arg: Vec, - max_value: Option, -) { - let mut ins = ctx.try_get_mut_instruction(id).unwrap(); - ins.lhs = lhs; - ins.rhs = rhs; - if let Some(max_v) = max_value { - ins.max_value = max_v; - } - ins.ins_arguments = ins_arg; -} - -fn update_ins(ctx: &mut SsaContext, id: NodeId, copy_from: &node::Instruction) { - let mut ins = ctx.try_get_mut_instruction(id).unwrap(); - ins.lhs = copy_from.lhs; - ins.rhs = copy_from.rhs; - ins.operator = copy_from.operator; - ins.max_value = copy_from.max_value.clone(); - ins.bit_size = copy_from.bit_size; -} - //Add required truncate instructions on all blocks pub fn overflow_strategy(ctx: &mut SsaContext) { let mut max_map: HashMap = HashMap::new(); @@ -233,7 +204,7 @@ pub fn overflow_strategy(ctx: &mut SsaContext) { } //implement overflow strategy following the dominator tree -pub fn tree_overflow( +fn tree_overflow( ctx: &mut SsaContext, b_idx: BlockId, max_map: &mut HashMap, @@ -247,7 +218,7 @@ pub fn tree_overflow( } //overflow strategy for one block -pub fn block_overflow( +fn block_overflow( ctx: &mut SsaContext, block_id: BlockId, max_map: &mut HashMap, @@ -257,158 +228,105 @@ pub fn block_overflow( //when it is over the field charac, or if the instruction requires it, then we insert truncate instructions // The instructions are insterted in a duplicate list( because of rust ownership..), which we use for // processing another cse round for the block because the truncates may be duplicated. - let mut instructions = Vec::new(); let mut new_list = Vec::new(); let mut truncate_map = HashMap::new(); - let mut modify_ins = None; - let mut trunc_size = FieldElement::zero(); - //RIA... - for iter in &ctx[block_id].instructions { - instructions.push((*ctx.try_get_instruction(*iter).unwrap()).clone()); - } + let instructions = + vecmap(&ctx[block_id].instructions, |id| ctx.try_get_instruction(*id).unwrap().clone()); + //since we process the block from the start, the block value map is not relevant let mut value_map = HashMap::new(); - let mut delete_ins = false; for mut ins in instructions { if matches!( - ins.operator, - node::Operation::Nop - | node::Operation::Call(_) - | node::Operation::Res - | node::Operation::Ret + ins.operation, + Operation::Nop | Operation::Call(..) | Operation::Result { .. } | Operation::Return(_) ) { //For now we skip completely functions from overflow; that means arguments are NOT truncated. //The reasoning is that this is handled by doing the overflow strategy after the function has been inlined continue; } - let mut ins_args = Vec::new(); - let mut i_lhs = ins.lhs; - let mut i_rhs = ins.rhs; + + ins.operation.map_id_mut(|id| { + let id = optim::propagate(ctx, id); + get_value_from_map(id, &value_map) + }); + //we propagate optimised loads - todo check if it is needed because there is cse at the end - if node::is_binary(ins.operator) { - //binary operation: - i_lhs = optim::propagate(ctx, ins.lhs); - i_rhs = optim::propagate(ctx, ins.rhs); - } //We retrieve get_current_value() in case a previous truncate has updated the value map - let r_id = get_value_from_map(i_rhs, &value_map); - let mut update_instruction = false; - if r_id != ins.rhs { - ins.rhs = r_id; - update_instruction = true; - } - let l_id = get_value_from_map(i_lhs, &value_map); - if l_id != ins.lhs { - ins.lhs = l_id; - update_instruction = true; - } - let r_obj = ctx.try_get_node(r_id); - let l_obj = ctx.try_get_node(l_id); - let r_max = get_obj_max_value(ctx, r_obj, r_id, max_map, &value_map); - let l_max = get_obj_max_value(ctx, l_obj, l_id, max_map, &value_map); - //insert required truncates, except for field type or dummy node - let to_truncate = ins.truncate_required(get_size_in_bits(l_obj), get_size_in_bits(r_obj)); - if to_truncate.0 && l_obj.is_some() && get_type(l_obj) != node::ObjectType::NativeField { - //adds a new truncate(lhs) instruction - add_to_truncate(ctx, l_id, get_size_in_bits(l_obj), &mut truncate_map, max_map); - } - if to_truncate.1 && r_obj.is_some() && get_type(r_obj) != node::ObjectType::NativeField { - //adds a new truncate(rhs) instruction - add_to_truncate(ctx, r_id, get_size_in_bits(r_obj), &mut truncate_map, max_map); - } - match ins.operator { - node::Operation::Load(_) => { + let should_truncate_ins = ins.truncate_required(ctx); + let ins_max_bits = get_instruction_max(ctx, &ins, max_map, &value_map).bits(); + let res_type = ins.res_type; + + let too_many_bits = ins_max_bits >= FieldElement::max_num_bits() as u64 + && res_type != ObjectType::NativeField; + + ins.operation.for_each_id(|id| { + get_obj_max_value(ctx, id, max_map, &value_map); + let arg = ctx.try_get_node(id); + let should_truncate_arg = + should_truncate_ins && arg.is_some() && get_type(arg) != ObjectType::NativeField; + + if should_truncate_arg || too_many_bits { + add_to_truncate(ctx, id, get_size_in_bits(arg), &mut truncate_map, max_map); + } + }); + + match ins.operation { + Operation::Load { index, .. } => { //TODO we use a local memory map for now but it should be used in arguments //for instance, the join block of a IF should merge the two memorymaps using the condition value - if let Some(adr) = super::mem::Memory::to_u32(ctx, ins.lhs) { + if let Some(adr) = Memory::to_u32(ctx, index) { if let Some(val) = memory_map.get(&adr) { //optimise static load - ins.is_deleted = true; - ins.rhs = *val; + ins.mark = Mark::ReplaceWith(*val); } } } - node::Operation::Store(_) => { - if let Some(adr) = super::mem::Memory::to_u32(ctx, ins.lhs) { + Operation::Store { index, value, .. } => { + if let Some(adr) = Memory::to_u32(ctx, index) { //optimise static store - memory_map.insert(adr, ins.rhs); - delete_ins = true; + memory_map.insert(adr, value); + + // Optimizing out stores is temporarily disabled due to errors in the + // pedersen example and '5' example + // mark = Mark::Deleted; } } - node::Operation::Cast => { - //TODO for now the types we support here are only all integer types (field, signed, unsigned, bool) - //so a cast would normally translate to a truncate. - //if res_type and lhs have the same bit size (in a large sens, which include field elements) - //then either they have the same type and should have been simplified - //or they don't have the same sign so we keep the cast operator - //if res_type is smaller than lhs bit size, we look if lhs can hold directly into res_type + Operation::Cast(value_id) => { + // TODO for now the types we support here are only all integer types (field, signed, unsigned, bool) + // so a cast would normally translate to a truncate. + // if res_type and lhs have the same bit size (in a large sense, which includes field elements) + // then either they have the same type and should have been simplified + // or they don't have the same sign so we keep the cast operator + // if res_type is smaller than lhs bit size, we look if lhs can hold directly into res_type // if not, we need to truncate lhs to a res_type. We modify directly the cast instruction into a truncate // in other cases we can keep the cast instruction // for instance if res_type is greater than lhs bit size, we need to truncate lhs to its bit size and use the truncate // result in the cast, but this is handled by the truncate_required // after this function, all cast instructions refer to casting lhs into a bigger (or equal) type - // anyother case has been transformed into the latter using truncates. - if ins.res_type == get_type(l_obj) { - ins.is_deleted = true; - ins.rhs = ins.lhs; - } - if ins.res_type.bits() < get_size_in_bits(l_obj) - && r_max.bits() as u32 > ins.res_type.bits() - { - //we need to truncate - update_instruction = true; - trunc_size = FieldElement::from(ins.res_type.bits() as i128); - let mut mod_ins = Instruction::new( - node::Operation::Trunc, - l_id, - l_id, - ins.res_type, - Some(ins.parent_block), - ); - mod_ins.bit_size = l_max.bits() as u32; - modify_ins = Some(mod_ins); - //TODO name for the instruction: modify_ins.res_name = l_obj."name"+"_t"; - //n.b. we do not update value map because we re-use the cast instruction + // any other case has been transformed into the latter using truncates. + let obj = ctx.try_get_node(value_id); + + if ins.res_type == get_type(obj) { + ins.mark = Mark::ReplaceWith(value_id); + } else { + let max = get_obj_max_value(ctx, value_id, max_map, &value_map); + let maxbits = max.bits() as u32; + + if ins.res_type.bits() < get_size_in_bits(obj) && maxbits > ins.res_type.bits() + { + //we need to truncate + ins.operation = Operation::Truncate { + value: value_id, + bit_size: ins.res_type.bits(), + max_bit_size: maxbits, + }; + } } } - // node::Operation::Call(_) => { - // for a in &ins.ins_arguments { - // add_to_truncate(igen, *a, igen[*a].get_type().bits(), &mut truncate_map, max_map); - // } - // } _ => (), } - let mut ins_max = get_instruction_max(ctx, &ins, max_map, &value_map); - if ins_max.bits() >= (FieldElement::max_num_bits() as u64) - && ins.res_type != node::ObjectType::NativeField - { - //let's truncate a and b: - //- insert truncate(lhs) dans la list des instructions - //- insert truncate(rhs) dans la list des instructions - //- update r_max et l_max - //n.b we could try to truncate only one of them, but then we should check if rhs==lhs. - let l_trunc_max = - add_to_truncate(ctx, l_id, get_size_in_bits(l_obj), &mut truncate_map, max_map); - let r_trunc_max = - add_to_truncate(ctx, r_id, get_size_in_bits(r_obj), &mut truncate_map, max_map); - ins_max = get_instruction_max_operand( - ctx, - &ins, - l_trunc_max.clone(), - r_trunc_max.clone(), - max_map, - &value_map, - ); - if ins_max.bits() >= FieldElement::max_num_bits().into() { - let message = format!( - "Require big int implementation, the bit size is too big for the field: {}, {}", - l_trunc_max.bits(), - r_trunc_max.bits() - ); - panic!("{}", message); - } - } + process_to_truncate( ctx, &mut new_list, @@ -417,52 +335,36 @@ pub fn block_overflow( block_id, &mut value_map, ); - if delete_ins { - delete_ins = false; - } else { - new_list.push(ins.id); - let l_new = get_value_from_map(l_id, &value_map); - let r_new = get_value_from_map(r_id, &value_map); - if l_new != l_id || r_new != r_id || is_sub(&ins.operator) { - update_instruction = true; - } - for a in &ins.ins_arguments { - let a_new = get_value_from_map(*a, &value_map); - if !update_instruction && *a != a_new { - update_instruction = true; - } - ins_args.push(a_new); - } + let id = match ins.mark { + Mark::None => ins.id, + Mark::Deleted => continue, + Mark::ReplaceWith(new_id) => new_id, + }; - if update_instruction { - let mut max_r_value = None; - if is_sub(&ins.operator) { - //for now we pass the max value to the instruction, we could also keep the max_map e.g in the block (or max in each nodeobj) - //sub operations require the max value to ensure it does not underflow - max_r_value = Some(max_map[&r_new].clone()); - //we may do that in future when the max_map becomes more used elsewhere (for other optim) - } - if let Some(modified_ins) = &mut modify_ins { - modified_ins.rhs = - ctx.get_or_create_const(trunc_size, node::ObjectType::Unsigned(32)); - modified_ins.lhs = l_new; - if let Some(max_v) = max_r_value { - modified_ins.max_value = max_v; - } - update_ins(ctx, ins.id, modified_ins); - } else { - update_ins_parameters(ctx, ins.id, l_new, r_new, ins_args, max_r_value); - } - } + new_list.push(id); + ins.operation.map_id_mut(|id| get_value_from_map(id, &value_map)); + + if let Operation::Binary(node::Binary { + rhs, + operator: BinaryOp::Sub { max_rhs_value } | BinaryOp::SafeSub { max_rhs_value }, + .. + }) = &mut ins.operation + { + //for now we pass the max value to the instruction, we could also keep the max_map e.g in the block (or max in each nodeobj) + //sub operations require the max value to ensure it does not underflow + *max_rhs_value = max_map[rhs].clone(); + //we may do that in future when the max_map becomes more used elsewhere (for other optim) } + + let old_ins = ctx.try_get_mut_instruction(id).unwrap(); + *old_ins = ins; } update_value_array(ctx, block_id, &value_map); //We run another round of CSE for the block in order to remove possible duplicated truncates, this will assign 'new_list' to the block instructions - let mut anchor = HashMap::new(); - optim::block_cse(ctx, block_id, &mut anchor, &mut new_list); + optim::cse_block(ctx, block_id, &mut new_list); } fn update_value_array(ctx: &mut SsaContext, block_id: BlockId, vmap: &HashMap) { @@ -477,7 +379,7 @@ pub fn get_value_from_map(id: NodeId, vmap: &HashMap) -> NodeId *vmap.get(&id).unwrap_or(&id) } -fn get_size_in_bits(obj: Option<&node::NodeObj>) -> u32 { +fn get_size_in_bits(obj: Option<&NodeObj>) -> u32 { if let Some(v) = obj { v.size_in_bits() } else { @@ -485,134 +387,140 @@ fn get_size_in_bits(obj: Option<&node::NodeObj>) -> u32 { } } -fn get_type(obj: Option<&node::NodeObj>) -> node::ObjectType { +fn get_type(obj: Option<&NodeObj>) -> ObjectType { if let Some(v) = obj { v.get_type() } else { - node::ObjectType::NotAnObject + ObjectType::NotAnObject } } -pub fn get_load_max( +fn get_load_max( ctx: &SsaContext, address: NodeId, max_map: &mut HashMap, vmap: &HashMap, - array: u32, - // obj_type: node::ObjectType, + array: ArrayId, + // obj_type: ObjectType, ) -> BigUint { if let Some(adr_as_const) = ctx.get_as_constant(address) { let adr: u32 = adr_as_const.to_u128().try_into().unwrap(); if let Some(&value) = ctx.mem.memory_map.get(&adr) { - return get_obj_max_value(ctx, None, value, max_map, vmap); + return get_obj_max_value(ctx, value, max_map, vmap); } }; - ctx.mem.arrays[array as usize].max.clone() //return array max - // return obj_type.max_size(); + ctx.mem[array].max.clone() //return array max + // return obj_type.max_size(); } //Returns the max value of an operation from an upper bound of left and right hand sides //Function is used to check for overflows over the field size, this is why we use BigUint. -pub fn get_max_value(ins: &Instruction, lhs_max: BigUint, rhs_max: BigUint) -> BigUint { - let max_value = match ins.operator { - Operation::Add => lhs_max + rhs_max, - Operation::SafeAdd => todo!(), - Operation::Sub => { - let r_mod = BigUint::one() << ins.res_type.bits(); - let mut k = &rhs_max / &r_mod; - if &rhs_max % &r_mod != BigUint::zero() { +fn get_max_value(ins: &Instruction, max_map: &mut HashMap) -> BigUint { + let max_value = match &ins.operation { + Operation::Binary(binary) => get_binary_max_value(binary, ins.res_type, max_map), + Operation::Not(_) => ins.res_type.max_size(), + //'a cast a' means we cast a into res_type of the instruction + Operation::Cast(value_id) => { + let type_max = ins.res_type.max_size(); + BigUint::min(max_map[value_id].clone(), type_max) + } + Operation::Truncate { value, max_bit_size, .. } => BigUint::min( + max_map[value].clone(), + BigUint::from(2_u32).pow(*max_bit_size) - BigUint::from(1_u32), + ), + Operation::Nop | Operation::Jne(..) | Operation::Jeq(..) | Operation::Jmp(_) => todo!(), + Operation::Phi { root, block_args } => { + let mut max = max_map[root].clone(); + for (id, _block) in block_args { + max = BigUint::max(max, max_map[id].clone()); + } + max + } + Operation::Load { .. } => unreachable!(), + Operation::Store { .. } => BigUint::zero(), + Operation::Call(..) => ins.res_type.max_size(), //TODO interval analysis but we also need to get the arguments (ins_arguments) + Operation::Return(_) => todo!(), + Operation::Result { .. } => ins.res_type.max_size(), + Operation::Intrinsic(opcode, _) => { + match opcode { + OPCODE::SHA256 + | OPCODE::Blake2s + | OPCODE::Pedersen + | OPCODE::FixedBaseScalarMul + | OPCODE::ToBits => BigUint::zero(), //pointers do not overflow + OPCODE::SchnorrVerify | OPCODE::EcdsaSecp256k1 => BigUint::one(), //verify returns 0 or 1 + _ => todo!(), + } + } + }; + + if ins.res_type == ObjectType::NativeField { + let field_max = BigUint::from_bytes_be(&FieldElement::one().neg().to_bytes()); + + //Native Field operations cannot overflow so they will not be truncated + if max_value >= field_max { + return field_max; + } + } + max_value +} + +fn get_binary_max_value( + binary: &node::Binary, + res_type: ObjectType, + max_map: &mut HashMap, +) -> BigUint { + let lhs_max = &max_map[&binary.lhs]; + let rhs_max = &max_map[&binary.rhs]; + + match &binary.operator { + BinaryOp::Add => lhs_max + rhs_max, + BinaryOp::SafeAdd => todo!(), + BinaryOp::Sub { .. } => { + let r_mod = BigUint::one() << res_type.bits(); + let mut k = rhs_max / &r_mod; + if rhs_max % &r_mod != BigUint::zero() { k += BigUint::one(); } - assert!(&k * &r_mod >= rhs_max); + assert!(&k * &r_mod >= *rhs_max); lhs_max + k * r_mod } - Operation::SafeSub => todo!(), - Operation::Mul => lhs_max * rhs_max, - Operation::SafeMul => todo!(), - Operation::Udiv => lhs_max, - Operation::Sdiv => todo!(), - Operation::Urem => rhs_max - BigUint::one(), - Operation::Srem => todo!(), - Operation::Div => todo!(), - Operation::Eq => BigUint::one(), - Operation::Ne => BigUint::one(), - Operation::Ugt => BigUint::one(), - Operation::Uge => BigUint::one(), - Operation::Ult => BigUint::one(), - Operation::Ule => BigUint::one(), - Operation::Sgt => BigUint::one(), - Operation::Sge => BigUint::one(), - Operation::Slt => BigUint::one(), - Operation::Sle => BigUint::one(), - Operation::Lt => BigUint::one(), - Operation::Gt => BigUint::one(), - Operation::Lte => BigUint::one(), - Operation::Gte => BigUint::one(), - Operation::And => { + BinaryOp::SafeSub { .. } => todo!(), + BinaryOp::Mul => lhs_max * rhs_max, + BinaryOp::SafeMul => todo!(), + BinaryOp::Udiv => lhs_max.clone(), + BinaryOp::Sdiv => todo!(), + BinaryOp::Urem => rhs_max - BigUint::one(), + BinaryOp::Srem => todo!(), + BinaryOp::Div => todo!(), + BinaryOp::Eq => BigUint::one(), + BinaryOp::Ne => BigUint::one(), + BinaryOp::Ult => BigUint::one(), + BinaryOp::Ule => BigUint::one(), + BinaryOp::Slt => BigUint::one(), + BinaryOp::Sle => BigUint::one(), + BinaryOp::Lt => BigUint::one(), + BinaryOp::Lte => BigUint::one(), + BinaryOp::And => { BigUint::from(2_u32).pow(u64::min(lhs_max.bits(), rhs_max.bits()) as u32) - BigUint::one() } - Operation::Xor | Operation::Or => { + BinaryOp::Or | BinaryOp::Xor => { BigUint::from(2_u32).pow(u64::max(lhs_max.bits(), rhs_max.bits()) as u32) - BigUint::one() } - Operation::Not => ins.res_type.max_size(), - Operation::Shr => BigUint::min( + BinaryOp::Assign => rhs_max.clone(), + BinaryOp::Constrain(_) => BigUint::zero(), + BinaryOp::Shr => BigUint::min( BigUint::from(2_u32).pow((lhs_max.bits() + 1) as u32) - BigUint::one(), - ins.res_type.max_size(), + res_type.max_size(), ), - Operation::Shl => { + BinaryOp::Shl => { if lhs_max.bits() >= 1 { BigUint::from(2_u32).pow((lhs_max.bits() - 1) as u32) - BigUint::one() } else { BigUint::zero() } } - //'a cast a' means we cast a into res_type of the instruction - Operation::Cast => { - let type_max = ins.res_type.max_size(); - BigUint::min(lhs_max, type_max) - } - Operation::Trunc => BigUint::min( - lhs_max, - BigUint::from(2_u32).pow(rhs_max.try_into().unwrap()) - BigUint::from(1_u32), - ), - //'a = b': a and b must be of same type. - Operation::Ass => rhs_max, - Operation::Nop | Operation::Jne | Operation::Jeq | Operation::Jmp => todo!(), - Operation::Phi => BigUint::max(lhs_max, rhs_max), //TODO operands are in phi_arguments, not lhs/rhs!! - Operation::Constrain(_) => BigUint::zero(), //min(lhs_max, rhs_max), - Operation::Load(_) => { - unreachable!(); - } - Operation::Store(_) => BigUint::zero(), - Operation::Call(_) => ins.res_type.max_size(), //TODO interval analysis but we also need to get the arguments (ins_arguments) - Operation::Ret => todo!(), - Operation::Res => todo!(), - Operation::Intrinsic(opcode) => { - match opcode { - acvm::acir::OPCODE::SHA256 - | acvm::acir::OPCODE::Blake2s - | acvm::acir::OPCODE::Pedersen - | acvm::acir::OPCODE::FixedBaseScalarMul - | acvm::acir::OPCODE::ToBits => BigUint::zero(), //pointers do not overflow - acvm::acir::OPCODE::SchnorrVerify | acvm::acir::OPCODE::EcdsaSecp256k1 => { - BigUint::one() - } //verify returns 0 or 1 - _ => todo!(), - } - } - }; - if ins.res_type == node::ObjectType::NativeField { - //Native Field operations cannot overflow so they will not be truncated - if max_value >= BigUint::from_bytes_be(&FieldElement::one().neg().to_bytes()) { - return BigUint::from_bytes_be(&FieldElement::one().neg().to_bytes()); - } } - max_value -} - -//indicates if the operation is a substraction, we need to check them for underflow -pub fn is_sub(operator: &Operation) -> bool { - matches!(operator, Operation::Sub | Operation::SafeSub) } diff --git a/crates/noirc_evaluator/src/ssa/mem.rs b/crates/noirc_evaluator/src/ssa/mem.rs index 7221586e3b3..5cbdf1bad88 100644 --- a/crates/noirc_evaluator/src/ssa/mem.rs +++ b/crates/noirc_evaluator/src/ssa/mem.rs @@ -13,13 +13,17 @@ use std::convert::TryInto; #[derive(Default)] pub struct Memory { - pub arrays: Vec, + arrays: Vec, pub last_adr: u32, //last address in 'memory' pub memory_map: HashMap, //maps memory adress to expression } +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ArrayId(u32); + #[derive(Debug, Clone)] pub struct MemArray { + pub id: ArrayId, pub element_type: node::ObjectType, //type of elements pub values: Vec, pub name: String, @@ -30,7 +34,7 @@ pub struct MemArray { } impl MemArray { - pub fn set_witness(&mut self, array: &Array) { + fn set_witness(&mut self, array: &Array) { for object in &array.contents { if let Some(w) = node::get_witness_from_object(object) { self.values.push(w.into()); @@ -39,9 +43,16 @@ impl MemArray { assert!(self.values.is_empty() || self.values.len() == self.len.try_into().unwrap()); } - pub fn new(definition: DefinitionId, name: &str, of: node::ObjectType, len: u32) -> MemArray { + fn new( + id: ArrayId, + definition: DefinitionId, + name: &str, + of: node::ObjectType, + len: u32, + ) -> MemArray { assert!(len > 0); MemArray { + id, element_type: of, name: name.to_string(), values: Vec::new(), @@ -54,16 +65,18 @@ impl MemArray { } impl Memory { - pub fn find_array(&self, definition: &Option) -> Option<&MemArray> { - definition.and_then(|def| self.arrays.iter().find(|a| a.def == def)) + pub fn find_array(&self, definition: DefinitionId) -> Option<&MemArray> { + self.arrays.iter().find(|a| a.def == definition) } - pub fn get_array_index(&self, array: &MemArray) -> Option { - self.arrays.iter().position(|x| x.adr == array.adr).map(|p| p as u32) + /// Retrieves the ArrayId of the last array in Memory. + /// Panics if self does not contain at least 1 array. + pub fn last_id(&self) -> ArrayId { + ArrayId(self.arrays.len() as u32 - 1) } //dereference a pointer - pub fn deref(ctx: &SsaContext, id: NodeId) -> Option { + pub fn deref(ctx: &SsaContext, id: NodeId) -> Option { ctx.try_get_node(id).and_then(|var| match var.get_type() { node::ObjectType::Pointer(a) => Some(a), _ => None, @@ -78,19 +91,25 @@ impl Memory { arr_name: &str, ) -> &MemArray { let len = u32::try_from(array.length).unwrap(); - self.create_new_array(len, el_type, arr_name); - let mem_array = self.arrays.last_mut().unwrap(); + let id = self.create_new_array(len, el_type, arr_name); + let mem_array = &mut self[id]; mem_array.set_witness(array); mem_array.def = definition; - self.arrays.last().unwrap() + mem_array } - pub fn create_new_array(&mut self, len: u32, el_type: node::ObjectType, arr_name: &str) -> u32 { - let mut new_array = MemArray::new(DefinitionId::dummy_id(), arr_name, el_type, len); + pub fn create_new_array( + &mut self, + len: u32, + el_type: node::ObjectType, + arr_name: &str, + ) -> ArrayId { + let id = ArrayId(self.arrays.len() as u32); + let mut new_array = MemArray::new(id, DefinitionId::dummy_id(), arr_name, el_type, len); new_array.adr = self.last_adr; self.arrays.push(new_array); self.last_adr += len; - (self.arrays.len() - 1) as u32 + id } pub fn as_u32(value: FieldElement) -> u32 { @@ -111,3 +130,17 @@ impl Memory { None //Not a constant object } } + +impl std::ops::Index for Memory { + type Output = MemArray; + + fn index(&self, index: ArrayId) -> &Self::Output { + &self.arrays[index.0 as usize] + } +} + +impl std::ops::IndexMut for Memory { + fn index_mut(&mut self, index: ArrayId) -> &mut Self::Output { + &mut self.arrays[index.0 as usize] + } +} diff --git a/crates/noirc_evaluator/src/ssa/node.rs b/crates/noirc_evaluator/src/ssa/node.rs index 8091426449c..bb27dd3569e 100644 --- a/crates/noirc_evaluator/src/ssa/node.rs +++ b/crates/noirc_evaluator/src/ssa/node.rs @@ -1,5 +1,4 @@ use std::convert::TryInto; -use std::ops::Add; use acvm::acir::native_types::Witness; use acvm::acir::OPCODE; @@ -7,16 +6,17 @@ use acvm::FieldElement; use arena; use noirc_frontend::hir_def::expr::HirBinaryOpKind; use noirc_frontend::node_interner::DefinitionId; +use noirc_frontend::util::vecmap; use noirc_frontend::{Signedness, Type}; use num_bigint::BigUint; -use num_traits::One; +use num_traits::{FromPrimitive, One}; use crate::object::Object; -use num_traits::identities::Zero; -use std::ops::Mul; +use std::ops::{Add, Mul, Sub}; use super::block::BlockId; use super::context::SsaContext; +use super::mem::ArrayId; pub trait Node: std::fmt::Display { fn get_type(&self) -> ObjectType; @@ -29,6 +29,7 @@ impl std::fmt::Display for Variable { write!(f, "{}", self.name) } } + impl std::fmt::Display for NodeObj { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { @@ -38,6 +39,7 @@ impl std::fmt::Display for NodeObj { } } } + impl std::fmt::Display for Constant { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", self.value) @@ -184,7 +186,7 @@ pub enum ObjectType { Boolean, Unsigned(u32), //bit size Signed(u32), //bit size - Pointer(u32), //array index + Pointer(ArrayId), //custom(u32), //user-defined struct, u32 refers to the id of the type in...?todo //TODO big_int //TODO floats @@ -287,21 +289,18 @@ impl ObjectType { #[derive(Clone, Debug)] pub struct Instruction { pub id: NodeId, - pub operator: Operation, - pub rhs: NodeId, - pub lhs: NodeId, + pub operation: Operation, pub res_type: ObjectType, //result type - //prev,next: should have been a double linked list so that we can easily remove an instruction during optimisation phases pub parent_block: BlockId, - pub is_deleted: bool, pub res_name: String, - pub bit_size: u32, //TODO only for the truncate instruction...: bits size of the max value of the lhs.. a merger avec ci dessous!!!TODO - pub max_value: BigUint, //TODO only for sub instruction: max value of the rhs - - //temp: todo phi subtype - pub phi_arguments: Vec<(NodeId, BlockId)>, + pub mark: Mark, +} - pub ins_arguments: Vec, +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Mark { + None, + Deleted, + ReplaceWith(NodeId), } impl std::fmt::Display for Instruction { @@ -334,13 +333,25 @@ impl NodeEval { NodeEval::Const(_, _) => None, } } + + pub fn from_id(ctx: &SsaContext, id: NodeId) -> NodeEval { + match &ctx[id] { + NodeObj::Const(c) => { + let value = FieldElement::from_be_bytes_reduce(&c.value.to_bytes_be()); + NodeEval::Const(value, c.get_type()) + } + _ => NodeEval::VarOrInstruction(id), + } + } + + fn from_u128(value: u128, typ: ObjectType) -> NodeEval { + NodeEval::Const(value.into(), typ) + } } impl Instruction { pub fn new( op_code: Operation, - lhs: NodeId, - rhs: NodeId, r_type: ObjectType, parent_block: Option, ) -> Instruction { @@ -349,643 +360,775 @@ impl Instruction { Instruction { id, - operator: op_code, - lhs, - rhs, + operation: op_code, res_type: r_type, res_name: String::new(), - is_deleted: false, parent_block: p_block, - bit_size: 0, - max_value: BigUint::zero(), - phi_arguments: Vec::new(), - ins_arguments: Vec::new(), + mark: Mark::None, } } - pub fn get_arguments(&self) -> Vec { - if !self.ins_arguments.is_empty() { - return self.ins_arguments.clone(); - } - if self.lhs != NodeId::dummy() && self.rhs != NodeId::dummy() { - return vec![self.lhs, self.rhs]; - } - if self.lhs != NodeId::dummy() { - return vec![self.lhs]; - } - if self.rhs != NodeId::dummy() { - return vec![self.rhs]; - } - Vec::new() - } - - //indicates whether the left and/or right operand of the instruction is required to be truncated to its bit-width - pub fn truncate_required(&self, lhs_bits: u32, rhs_bits: u32) -> (bool, bool) { - match self.operator { - Operation::Add => (false, false), - Operation::SafeAdd => (false, false), - Operation::Sub => (false, false), - Operation::SafeSub => (false, false), - Operation::Mul => (false, false), - Operation::SafeMul => (false, false), - Operation::Udiv => (true, true), - Operation::Sdiv => (true, true), - Operation::Urem => (true, true), - Operation::Srem => (true, true), - Operation::Div => (false, false), - Operation::Eq => (true, true), - Operation::Ne => (true, true), - Operation::Ugt => (true, true), - Operation::Uge => (true, true), - Operation::Ult => (true, true), - Operation::Ule => (true, true), - Operation::Sgt => (true, true), - Operation::Sge => (true, true), - Operation::Slt => (true, true), - Operation::Sle => (true, true), - Operation::Lt => (true, true), - Operation::Gt => (true, true), - Operation::Lte => (true, true), - Operation::Gte => (true, true), - Operation::And | Operation::Not | Operation::Or => (true, true), - Operation::Xor | Operation::Shr | Operation::Shl => (true, true), - Operation::Cast => { - if self.res_type.bits() > lhs_bits { - return (true, false); - } - (false, false) - } - Operation::Ass => { - assert!(lhs_bits == rhs_bits); - (false, false) - } - Operation::Trunc | Operation::Phi => (false, false), - Operation::Nop | Operation::Jne | Operation::Jeq | Operation::Jmp => (false, false), - Operation::Constrain(_) => (true, true), - Operation::Load(_) | Operation::Store(_) => (false, false), - Operation::Intrinsic(_) => (true, true), //TODO to check - Operation::Call(_) => (false, false), //return values are in the return statment, should we truncate function arguments? probably but not lhs and rhs anyways. - Operation::Ret => (true, false), - Operation::Res => (false, false), + /// Indicates whether the left and/or right operand of the instruction is required to be truncated to its bit-width + pub fn truncate_required(&self, ctx: &SsaContext) -> bool { + match &self.operation { + Operation::Binary(binary) => binary.truncate_required(), + Operation::Not(..) => true, + Operation::Cast(value_id) => { + let obj = ctx.try_get_node(*value_id); + let bits = obj.map_or(0, |obj| obj.size_in_bits()); + self.res_type.bits() > bits + } + Operation::Truncate { .. } | Operation::Phi { .. } => false, + Operation::Nop | Operation::Jne(..) | Operation::Jeq(..) | Operation::Jmp(..) => false, + Operation::Load { .. } | Operation::Store { .. } => false, + Operation::Intrinsic(_, _) => true, //TODO to check + Operation::Call(_, _) => false, //return values are in the return statment, should we truncate function arguments? probably but not lhs and rhs anyways. + Operation::Return(_) => true, + Operation::Result { .. } => false, } } - //Returns the field element as i128 and the bit size of the constant node - pub fn get_const_value(c: FieldElement, ctype: ObjectType) -> (u128, u32) { - match ctype { - ObjectType::Boolean => (if c.is_zero() { 0 } else { 1 }, 1), - ObjectType::NativeField => { - (c.to_u128(), 256) //TODO: handle elements that do not fit in 128 bits + pub fn evaluate(&self, ctx: &SsaContext) -> NodeEval { + self.evaluate_with(ctx, NodeEval::from_id) + } + + //Evaluate the instruction value when its operands are constant (constant folding) + pub fn evaluate_with(&self, ctx: &SsaContext, mut eval_fn: F) -> NodeEval + where + F: FnMut(&SsaContext, NodeId) -> NodeEval, + { + match &self.operation { + Operation::Binary(binary) => { + return binary.evaluate(ctx, self.id, self.res_type, eval_fn) + } + Operation::Cast(value) => { + if let Some(l_const) = eval_fn(ctx, *value).into_const_value() { + if self.res_type == ObjectType::NativeField { + return NodeEval::Const(l_const, self.res_type); + } else if let Some(l_const) = l_const.try_into_u128() { + return NodeEval::Const( + FieldElement::from(l_const % (1_u128 << self.res_type.bits())), + self.res_type, + ); + } + } + } + Operation::Not(value) => { + let obj = eval_fn(ctx, *value).into_const_value(); + if let Some(l_const) = obj.and_then(|field| field.try_into_u128()) { + return NodeEval::Const(FieldElement::from(!l_const), self.res_type); + } } - ObjectType::Signed(b) | ObjectType::Unsigned(b) => { - assert!(b < 128); //we do not support integers bigger than 128 bits for now. - (c.to_u128(), b) - } //TODO check how to handle signed integers - _ => todo!(), + Operation::Phi { .. } => (), //Phi are simplified by simply_phi() later on; they must not be simplified here + _ => (), } + NodeEval::VarOrInstruction(self.id) } - pub fn node_evaluate(n: &NodeEval) -> (bool, Option, u32) { - match n { - &NodeEval::Const(c, t) => { - let cv = Instruction::get_const_value(c, t); - (c.is_zero(), Some(cv.0), cv.1) + // Simplifies trivial Phi instructions by returning: + // None, if the instruction is unreachable or in the root block and can be safely deleted + // Some(id), if the instruction can be replaced by the node id + // Some(ins_id), if the instruction is not trivial + pub fn simplify_phi(ins_id: NodeId, phi_arguments: &[(NodeId, BlockId)]) -> Option { + let mut same = None; + for op in phi_arguments { + if Some(op.0) == same || op.0 == ins_id { + continue; + } + if same.is_some() { + //no simplification + return Some(ins_id); } - _ => (false, None, 0), + + same = Some(op.0); } + //if same.is_none() => unreachable phi or in root block, can be replaced by ins.lhs (i.e the root) then. + same } - //Evaluate the instruction value when its operands are constant (constant folding) - pub fn evaluate(&self, lhs: &NodeEval, rhs: &NodeEval) -> NodeEval { - //let mut l_sign = false; //TODO - let (l_is_zero, l_constant, l_bsize) = Instruction::node_evaluate(lhs); - let (r_is_zero, r_constant, r_bsize) = Instruction::node_evaluate(rhs); - let r_is_const = r_constant.is_some(); - let l_is_const = l_constant.is_some(); - - match self.operator { - Operation::Add | Operation::SafeAdd => { - if r_is_zero { - return *lhs; - } else if l_is_zero { - return *rhs; - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - //constant folding - if l_bsize == 256 { - //NO modulo for field elements - May be we should have a different opcode? - if let (NodeEval::Const(a, _), NodeEval::Const(b, _)) = (lhs, rhs) { - let res_value = a.add(*b); - return NodeEval::Const(res_value, self.res_type); + pub fn is_deleted(&self) -> bool { + !matches!(self.mark, Mark::None) + } + + pub fn standard_form(&mut self) { + if let Operation::Binary(binary) = &mut self.operation { + if let BinaryOp::Constrain(op) = &binary.operator { + match op { + ConstrainOp::Eq => { + if binary.lhs == binary.rhs { + self.operation = Operation::Nop; + return; } - unreachable!(); } - assert!(l_bsize == r_bsize); - let res_value = (l_const + r_const) % (1_u128 << l_bsize) as u128; - return NodeEval::Const(FieldElement::from(res_value), self.res_type); + ConstrainOp::Neq => assert_ne!(binary.lhs, binary.rhs), + } + } + + if binary.operator.is_commutative() && binary.rhs < binary.lhs { + std::mem::swap(&mut binary.rhs, &mut binary.lhs); + } + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub enum ConstrainOp { + Eq, + Neq, + //Cmp... +} + +//adapted from LLVM IR +#[allow(dead_code)] //Some enums are not used yet, allow dead_code should be removed once they are all implemented. +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub enum Operation { + Binary(Binary), + Cast(NodeId), //convert type + Truncate { value: NodeId, bit_size: u32, max_bit_size: u32 }, //truncate + + Not(NodeId), //(!) Bitwise Not + + //control flow + Jne(NodeId, BlockId), //jump on not equal + Jeq(NodeId, BlockId), //jump on equal + Jmp(BlockId), //unconditional jump + Phi { root: NodeId, block_args: Vec<(NodeId, BlockId)> }, + + Call(noirc_frontend::node_interner::FuncId, Vec), //Call a function + Return(Vec), //Return value(s) from a function block + Result { call_instruction: NodeId, index: u32 }, //Get result index n from a function call + + //memory + Load { array_id: ArrayId, index: NodeId }, + Store { array_id: ArrayId, index: NodeId, value: NodeId }, + + Intrinsic(OPCODE, Vec), //Custom implementation of usefull primitives which are more performant with Aztec backend + + Nop, // no op +} + +#[derive(Copy, Clone, Hash, PartialEq, Eq)] +pub enum Opcode { + Add, + SafeAdd, + Sub, + SafeSub, + Mul, + SafeMul, + Udiv, + Sdiv, + Urem, + Srem, + Div, + Eq, + Ne, + Ult, + Ule, + Slt, + Sle, + Lt, + Lte, + And, + Or, + Xor, + Shl, + Shr, + Assign, + Constrain(ConstrainOp), + + Cast, //convert type + Truncate, //truncate + Not, //(!) Bitwise Not + + //control flow + Jne, //jump on not equal + Jeq, //jump on equal + Jmp, //unconditional jump + Phi, + + Call(noirc_frontend::node_interner::FuncId), //Call a function + Return, //Return value(s) from a function block + Results, //Get result(s) from a function call + + //memory + Load(ArrayId), + Store(ArrayId), + Intrinsic(OPCODE), //Custom implementation of usefull primitives which are more performant with Aztec backend + Nop, // no op +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct Binary { + pub lhs: NodeId, + pub rhs: NodeId, + pub operator: BinaryOp, +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub enum BinaryOp { + Add, //(+) + #[allow(dead_code)] + SafeAdd, //(+) safe addition + Sub { + max_rhs_value: BigUint, + }, //(-) + #[allow(dead_code)] + SafeSub { + max_rhs_value: BigUint, + }, //(-) safe subtraction + Mul, //(*) + #[allow(dead_code)] + SafeMul, //(*) safe multiplication + Udiv, //(/) unsigned division + Sdiv, //(/) signed division + #[allow(dead_code)] + Urem, //(%) modulo; remainder of unsigned division + #[allow(dead_code)] + Srem, //(%) remainder of signed division + Div, //(/) field division + Eq, //(==) equal + Ne, //(!=) not equal + Ult, //(<) unsigned less than + Ule, //(<=) unsigned less or equal + Slt, //(<) signed less than + Sle, //(<=) signed less or equal + Lt, //(<) field less + Lte, //(<=) field less or equal + And, //(&) Bitwise And + Or, //(|) Bitwise Or + Xor, //(^) Bitwise Xor + Shl, //(<<) Shift left + Shr, //(<<) Shift right + + Assign, + Constrain(ConstrainOp), //write gates enforcing the ContrainOp to be true +} + +impl Binary { + fn new(operator: BinaryOp, lhs: NodeId, rhs: NodeId) -> Binary { + Binary { operator, lhs, rhs } + } + + pub fn from_hir( + op_kind: HirBinaryOpKind, + op_type: ObjectType, + lhs: NodeId, + rhs: NodeId, + ) -> Binary { + let operator = match op_kind { + HirBinaryOpKind::Add => BinaryOp::Add, + HirBinaryOpKind::Subtract => { + BinaryOp::Sub { max_rhs_value: BigUint::from_u8(0).unwrap() } + } + HirBinaryOpKind::Multiply => BinaryOp::Mul, + HirBinaryOpKind::Equal => BinaryOp::Eq, + HirBinaryOpKind::NotEqual => BinaryOp::Ne, + HirBinaryOpKind::And => BinaryOp::And, + HirBinaryOpKind::Or => BinaryOp::Or, + HirBinaryOpKind::Xor => BinaryOp::Xor, + HirBinaryOpKind::Divide => { + let num_type: NumericType = op_type.into(); + match num_type { + NumericType::Signed(_) => BinaryOp::Sdiv, + NumericType::Unsigned(_) => BinaryOp::Udiv, + NumericType::NativeField => BinaryOp::Div, + } + } + HirBinaryOpKind::Less => { + let num_type: NumericType = op_type.into(); + match num_type { + NumericType::Signed(_) => BinaryOp::Slt, + NumericType::Unsigned(_) => BinaryOp::Ult, + NumericType::NativeField => BinaryOp::Lt, + } + } + HirBinaryOpKind::LessEqual => { + let num_type: NumericType = op_type.into(); + match num_type { + NumericType::Signed(_) => BinaryOp::Sle, + NumericType::Unsigned(_) => BinaryOp::Ule, + NumericType::NativeField => BinaryOp::Lte, + } + } + HirBinaryOpKind::Greater => { + let num_type: NumericType = op_type.into(); + match num_type { + NumericType::Signed(_) => return Binary::new(BinaryOp::Slt, rhs, lhs), + NumericType::Unsigned(_) => return Binary::new(BinaryOp::Ult, rhs, lhs), + NumericType::NativeField => return Binary::new(BinaryOp::Lt, rhs, lhs), + } + } + HirBinaryOpKind::GreaterEqual => { + let num_type: NumericType = op_type.into(); + match num_type { + NumericType::Signed(_) => return Binary::new(BinaryOp::Sle, rhs, lhs), + NumericType::Unsigned(_) => return Binary::new(BinaryOp::Ule, rhs, lhs), + NumericType::NativeField => return Binary::new(BinaryOp::Lte, rhs, lhs), + } + } + HirBinaryOpKind::Assign => BinaryOp::Assign, + HirBinaryOpKind::Shl => BinaryOp::Shl, + HirBinaryOpKind::Shr => BinaryOp::Shr, + }; + + Binary::new(operator, lhs, rhs) + } + + fn evaluate( + &self, + ctx: &SsaContext, + id: NodeId, + res_type: ObjectType, + mut eval_fn: F, + ) -> NodeEval + where + F: FnMut(&SsaContext, NodeId) -> NodeEval, + { + let l_eval = eval_fn(ctx, self.lhs); + let r_eval = eval_fn(ctx, self.rhs); + + let lhs = l_eval.into_const_value(); + let rhs = r_eval.into_const_value(); + + let l_is_zero = lhs.map_or(false, |x| x.is_zero()); + let r_is_zero = rhs.map_or(false, |x| x.is_zero()); + + match &self.operator { + BinaryOp::Add | BinaryOp::SafeAdd => { + if l_is_zero { + return r_eval; + } else if r_is_zero { + return l_eval; + } + + if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return wrapping(lhs, rhs, res_type, u128::add, Add::add); } //if only one is const, we could try to do constant propagation but this will be handled by the arithmetization step anyways //so it is probably not worth it. //same for x+x vs 2*x } - Operation::Sub | Operation::SafeSub => { + BinaryOp::Sub { .. } | BinaryOp::SafeSub { .. } => { if r_is_zero { - return *lhs; + return l_eval; } if self.lhs == self.rhs { - return NodeEval::Const(FieldElement::zero(), self.res_type); + return NodeEval::from_u128(0, res_type); } //constant folding - if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - if l_bsize == 256 { - //NO modulo for field elements - May be we should have a different opcode? - if let (NodeEval::Const(a, _), NodeEval::Const(b, _)) = (lhs, rhs) { - let res_value = a.add(-*b); - return NodeEval::Const(res_value, self.res_type); - } - unreachable!(); - } - //if l_constant.is_some() && r_constant.is_some() { - assert!(l_bsize == r_bsize); - - let res_value = - l_const.overflowing_sub(r_const).0 % (1_u128 << l_bsize) as u128; - return NodeEval::Const(FieldElement::from(res_value), self.res_type); + if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return wrapping(lhs, rhs, res_type, u128::wrapping_sub, Sub::sub); } } - Operation::Mul | Operation::SafeMul => { - if r_is_zero { - return *rhs; - } else if l_is_zero { - return *lhs; - } else if l_is_const && l_constant.unwrap() == 1 { - return *rhs; - } else if r_is_const && r_constant.unwrap() == 1 { - return *lhs; - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - //constant folding - if l_bsize == 256 { - //NO modulo for field elements - May be we should have a different opcode? - if let (NodeEval::Const(a, _), NodeEval::Const(b, _)) = (lhs, rhs) { - let res_value = a.mul(*b); - return NodeEval::Const(res_value, self.res_type); - } - unreachable!(); - } + BinaryOp::Mul | BinaryOp::SafeMul => { + let l_is_one = lhs.map_or(false, |x| x.is_one()); + let r_is_one = rhs.map_or(false, |x| x.is_one()); - assert!(l_bsize == r_bsize); - let res_value = (l_const * r_const) % (1_u128 << l_bsize) as u128; - return NodeEval::Const(FieldElement::from(res_value), self.res_type); + if l_is_zero || r_is_one { + return l_eval; + } else if r_is_zero || l_is_one { + return r_eval; + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return wrapping(lhs, rhs, res_type, u128::mul, Mul::mul); } //if only one is const, we could try to do constant propagation but this will be handled by the arithmetization step anyways //so it is probably not worth it. } - Operation::Udiv | Operation::Sdiv | Operation::Div => { + BinaryOp::Udiv | BinaryOp::Sdiv | BinaryOp::Div => { if r_is_zero { todo!("Panic - division by zero"); } else if l_is_zero { - return *lhs; //TODO should we ensure rhs != 0 ??? + return l_eval; } //constant folding - TODO - else if l_constant.is_some() && r_constant.is_some() { - todo!(); - } else if r_constant.is_some() { + else if lhs.is_some() && rhs.is_some() { + todo!("Constant folding for division"); + } else if rhs.is_some() { //same as lhs*1/r - todo!(""); + todo!("Constant folding for division rhs"); //return (Some(self.lhs), None, None); } } - Operation::Urem | Operation::Srem => { + BinaryOp::Urem | BinaryOp::Srem => { if r_is_zero { todo!("Panic - division by zero"); } else if l_is_zero { - return *lhs; //TODO what is the correct result? and should we ensure rhs != 0 ??? + return l_eval; //TODO what is the correct result? } //constant folding - TODO - else if l_constant.is_some() && r_constant.is_some() { - todo!("divide l_constant/r_constant but take sign into account"); + else if lhs.is_some() && rhs.is_some() { + todo!("divide lhs/rhs but take sign into account"); } } - Operation::Uge => { + BinaryOp::Ult => { if r_is_zero { return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); //n.b we assume the type of lhs and rhs is unsigned because of the opcode, we could also verify this - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - assert!(l_bsize < 256); //comparisons are not implemented for field elements - let res = - if l_const >= r_const { FieldElement::one() } else { FieldElement::zero() }; + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + assert_ne!(res_type, ObjectType::NativeField); //comparisons are not implemented for field elements + let res = if lhs < rhs { FieldElement::one() } else { FieldElement::zero() }; return NodeEval::Const(res, ObjectType::Boolean); } } - Operation::Ult => { - if r_is_zero { - return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); + BinaryOp::Ule => { + if l_is_zero { + return NodeEval::Const(FieldElement::one(), ObjectType::Boolean); //n.b we assume the type of lhs and rhs is unsigned because of the opcode, we could also verify this - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - assert!(l_bsize < 256); //comparisons are not implemented for field elements - let res = - if l_const < r_const { FieldElement::one() } else { FieldElement::zero() }; + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + assert_ne!(res_type, ObjectType::NativeField); //comparisons are not implemented for field elements + let res = if lhs <= rhs { FieldElement::one() } else { FieldElement::zero() }; return NodeEval::Const(res, ObjectType::Boolean); } } - Operation::Ule => { - if l_is_zero { - return NodeEval::Const(FieldElement::one(), ObjectType::Boolean); + BinaryOp::Slt => (), + BinaryOp::Sle => (), + BinaryOp::Lt => { + if r_is_zero { + return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); //n.b we assume the type of lhs and rhs is unsigned because of the opcode, we could also verify this - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - assert!(l_bsize < 256); //comparisons are not implemented for field elements - let res = - if l_const <= r_const { FieldElement::one() } else { FieldElement::zero() }; + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + let res = if lhs < rhs { FieldElement::one() } else { FieldElement::zero() }; return NodeEval::Const(res, ObjectType::Boolean); } } - Operation::Ugt => { + BinaryOp::Lte => { if l_is_zero { - return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); - // u<0 is false for unsigned u - //n.b we assume the type of lhs and rhs is unsigned because of the opcode, we could also verify this - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - assert!(l_bsize < 256); //comparisons are not implemented for field elements - let res = - if l_const > r_const { FieldElement::one() } else { FieldElement::zero() }; + return NodeEval::Const(FieldElement::one(), ObjectType::Boolean); + //n.b we assume the type of lhs and rhs is unsigned because of the opcode, we could also verify this + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + let res = if lhs <= rhs { FieldElement::one() } else { FieldElement::zero() }; return NodeEval::Const(res, ObjectType::Boolean); } } - Operation::Eq => { + BinaryOp::Eq => { if self.lhs == self.rhs { return NodeEval::Const(FieldElement::one(), ObjectType::Boolean); - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - if l_bsize == 256 { - if let (NodeEval::Const(a, _), NodeEval::Const(b, _)) = (lhs, rhs) { - if a == b { - return NodeEval::Const(FieldElement::one(), ObjectType::Boolean); - } else { - return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); - } - } - unreachable!(); - } - if l_const == r_const { + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + if lhs == rhs { return NodeEval::Const(FieldElement::one(), ObjectType::Boolean); } else { return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); } } } - Operation::Ne => { - if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - if l_bsize == 256 { - if let (NodeEval::Const(a, _), NodeEval::Const(b, _)) = (lhs, rhs) { - if a != b { - return NodeEval::Const(FieldElement::one(), ObjectType::Boolean); - } else { - return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); - } - } - unreachable!(); - } - if l_const != r_const { + BinaryOp::Ne => { + if self.lhs == self.rhs { + return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + if lhs != rhs { return NodeEval::Const(FieldElement::one(), ObjectType::Boolean); } else { return NodeEval::Const(FieldElement::zero(), ObjectType::Boolean); } } } - Operation::And => { + BinaryOp::And => { //Bitwise AND - if l_is_zero { - return *lhs; + if l_is_zero || self.lhs == self.rhs { + return l_eval; } else if r_is_zero { - return *rhs; - } else if self.lhs == self.rhs { - return *lhs; - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - return NodeEval::Const(FieldElement::from(l_const & r_const), self.res_type); + return r_eval; + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return NodeEval::from_u128(lhs.to_u128() & rhs.to_u128(), res_type); } //TODO if boolean and not zero, also checks this is correct for field elements } - Operation::Or => { + BinaryOp::Or => { //Bitwise OR - if l_is_zero { - return *rhs; + if l_is_zero || self.lhs == self.rhs { + return r_eval; } else if r_is_zero { - return *lhs; - } else if self.lhs == self.rhs { - return *rhs; - } else if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - return NodeEval::Const(FieldElement::from(l_const | r_const), self.res_type); + return l_eval; + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return NodeEval::from_u128(lhs.to_u128() | rhs.to_u128(), res_type); } //TODO if boolean and not zero, also checks this is correct for field elements } - - Operation::Not => { - if let Some(l_const) = l_constant { - return NodeEval::Const(FieldElement::from(!l_const), self.res_type); - } - } - Operation::Xor => { + BinaryOp::Xor => { if self.lhs == self.rhs { - return NodeEval::Const(FieldElement::zero(), self.res_type); - } - if l_is_zero { - return *rhs; - } - if r_is_zero { - return *lhs; - } - if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - return NodeEval::Const(FieldElement::from(l_const ^ r_const), self.res_type); + return NodeEval::Const(FieldElement::zero(), res_type); + } else if l_is_zero { + return r_eval; + } else if r_is_zero { + return l_eval; + } else if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return NodeEval::from_u128(lhs.to_u128() ^ rhs.to_u128(), res_type); } - //TODO handle case when l_const is one (or r_const is one) by generating 'not rhs' instruction (or 'not lhs' instruction) + //TODO handle case when lhs is one (or rhs is one) by generating 'not rhs' instruction (or 'not lhs' instruction) } - Operation::Shl => { + BinaryOp::Shl => { if l_is_zero { - return *lhs; + return l_eval; } if r_is_zero { - return *lhs; + return l_eval; } - if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - return NodeEval::Const(FieldElement::from(l_const << r_const), self.res_type); + if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return NodeEval::from_u128(lhs.to_u128() << rhs.to_u128(), res_type); } } - Operation::Shr => { + BinaryOp::Shr => { if l_is_zero { - return *lhs; + return l_eval; } if r_is_zero { - return *lhs; - } - if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { - return NodeEval::Const(FieldElement::from(l_const >> r_const), self.res_type); + return l_eval; } - } - Operation::Cast => { - if let Some(l_const) = l_constant { - if self.res_type == ObjectType::NativeField { - return NodeEval::Const(FieldElement::from(l_const), self.res_type); - } - return NodeEval::Const( - FieldElement::from(l_const % (1_u128 << self.res_type.bits())), - self.res_type, - ); + if let (Some(lhs), Some(rhs)) = (lhs, rhs) { + return NodeEval::from_u128(lhs.to_u128() >> rhs.to_u128(), res_type); } } - Operation::Constrain(op) => { - if let (Some(l_const), Some(r_const)) = (l_constant, r_constant) { + BinaryOp::Constrain(op) => { + if let (Some(lhs), Some(rhs)) = (lhs, rhs) { match op { - ConstrainOp::Eq => assert_eq!(l_const, r_const), - ConstrainOp::Neq => assert_ne!(l_const, r_const), + ConstrainOp::Eq => assert_eq!(lhs, rhs), + ConstrainOp::Neq => assert_ne!(lhs, rhs), } //we can delete the instruction return NodeEval::VarOrInstruction(NodeId::dummy()); } } - Operation::Phi => (), //Phi are simplified by simply_phi() later on; they must not be simplified here - _ => (), + BinaryOp::Assign => (), } - NodeEval::VarOrInstruction(self.id) + NodeEval::VarOrInstruction(id) } - // Simplifies trivial Phi instructions by returning: - // None, if the instruction is unreachable or in the root block and can be safely deleted - // Some(id), if the instruction can be replaced by the node id - // Some(ins_id), if the instruction is not trivial - pub fn simplify_phi(ins_id: NodeId, phi_arguments: &[(NodeId, BlockId)]) -> Option { - let mut same = None; - for op in phi_arguments { - if Some(op.0) == same || op.0 == ins_id { - continue; - } - if same.is_some() { - //no simplification - return Some(ins_id); - } - - same = Some(op.0); + fn truncate_required(&self) -> bool { + match &self.operator { + BinaryOp::Add => false, + BinaryOp::SafeAdd => false, + BinaryOp::Sub { .. } => false, + BinaryOp::SafeSub { .. } => false, + BinaryOp::Mul => false, + BinaryOp::SafeMul => false, + BinaryOp::Udiv => true, + BinaryOp::Sdiv => true, + BinaryOp::Urem => true, + BinaryOp::Srem => true, + BinaryOp::Div => false, + BinaryOp::Eq => true, + BinaryOp::Ne => true, + BinaryOp::Ult => true, + BinaryOp::Ule => true, + BinaryOp::Slt => true, + BinaryOp::Sle => true, + BinaryOp::Lt => true, + BinaryOp::Lte => true, + BinaryOp::And => true, + BinaryOp::Or => true, + BinaryOp::Xor => true, + BinaryOp::Constrain(..) => true, + BinaryOp::Assign => false, + BinaryOp::Shl => true, + BinaryOp::Shr => true, } - //if same.is_none() => unreachable phi or in root block, can be replaced by ins.lhs (i.e the root) then. - same } - pub fn standard_form(&mut self) { - match self.operator { - //convert > into < and <= into >= - Operation::Ugt => { - std::mem::swap(&mut self.rhs, &mut self.lhs); - self.operator = Operation::Ult - } - Operation::Ule => { - std::mem::swap(&mut self.rhs, &mut self.lhs); - self.operator = Operation::Uge - } - Operation::Sgt => { - std::mem::swap(&mut self.rhs, &mut self.lhs); - self.operator = Operation::Slt - } - Operation::Sle => { - std::mem::swap(&mut self.rhs, &mut self.lhs); - self.operator = Operation::Sge - } - Operation::Gt => { - std::mem::swap(&mut self.rhs, &mut self.lhs); - self.operator = Operation::Lt - } - Operation::Lte => { - std::mem::swap(&mut self.rhs, &mut self.lhs); - self.operator = Operation::Gte - } - Operation::Constrain(op) => match op { - ConstrainOp::Eq => { - if self.rhs == self.lhs { - self.rhs = self.id; - self.is_deleted = true; - self.operator = Operation::Nop; - } - } - ConstrainOp::Neq => assert!(self.rhs != self.lhs), - }, - _ => (), - } - if is_commutative(self.operator) && self.rhs < self.lhs { - std::mem::swap(&mut self.rhs, &mut self.lhs); + pub fn opcode(&self) -> Opcode { + match &self.operator { + BinaryOp::Add => Opcode::Add, + BinaryOp::SafeAdd => Opcode::SafeAdd, + BinaryOp::Sub { .. } => Opcode::Sub, + BinaryOp::SafeSub { .. } => Opcode::SafeSub, + BinaryOp::Mul => Opcode::Mul, + BinaryOp::SafeMul => Opcode::SafeMul, + BinaryOp::Udiv => Opcode::Udiv, + BinaryOp::Sdiv => Opcode::Sdiv, + BinaryOp::Urem => Opcode::Urem, + BinaryOp::Srem => Opcode::Srem, + BinaryOp::Div => Opcode::Div, + BinaryOp::Eq => Opcode::Eq, + BinaryOp::Ne => Opcode::Ne, + BinaryOp::Ult => Opcode::Ult, + BinaryOp::Ule => Opcode::Ule, + BinaryOp::Slt => Opcode::Slt, + BinaryOp::Sle => Opcode::Sle, + BinaryOp::Lt => Opcode::Lt, + BinaryOp::Lte => Opcode::Lte, + BinaryOp::And => Opcode::And, + BinaryOp::Or => Opcode::Or, + BinaryOp::Xor => Opcode::Xor, + BinaryOp::Shl => Opcode::Shl, + BinaryOp::Shr => Opcode::Shr, + BinaryOp::Assign => Opcode::Assign, + BinaryOp::Constrain(op) => Opcode::Constrain(*op), } } } -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] -pub enum ConstrainOp { - Eq, - Neq, - //Cmp... +/// Perform the given numeric operation and modulo the result by the max value for the given bitcount +/// if the res_type is not a NativeField. +fn wrapping( + lhs: FieldElement, + rhs: FieldElement, + res_type: ObjectType, + u128_op: impl FnOnce(u128, u128) -> u128, + field_op: impl FnOnce(FieldElement, FieldElement) -> FieldElement, +) -> NodeEval { + if res_type != ObjectType::NativeField { + let mut x = u128_op(lhs.to_u128(), rhs.to_u128()); + x %= 1_u128 << res_type.bits(); + NodeEval::from_u128(x, res_type) + } else { + NodeEval::Const(field_op(lhs, rhs), res_type) + } } -//adapted from LLVM IR -#[allow(dead_code)] //Some enums are not used yet, allow dead_code should be removed once they are all implemented. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] -pub enum Operation { - Add, //(+) - SafeAdd, //(+) safe addition - Sub, //(-) - SafeSub, //(-) safe subtraction - Mul, //(*) - SafeMul, //(*) safe multiplication - Udiv, //(/) unsigned division - Sdiv, //(/) signed division - Urem, //(%) modulo; remainder of unsigned division - Srem, //(%) remainder of signed division - Div, //(/) field division - Eq, //(==) equal - Ne, //(!=) not equal - Ugt, //(>) unsigned greater than - Uge, //(>=) unsigned greater or equal - Ult, //(<) unsigned less than - Ule, //(<=) unsigned less or equal - Sgt, //(>) signed greater than - Sge, //(>=) signed greater or equal - Slt, //(<) signed less than - Sle, //(<=) signed less or equal - Lt, //(<) field less - Gt, //(>) field greater - Lte, //(<=) field less or equal - Gte, //(<=) field greater or equal - And, //(&) Bitwise And - Not, //(!) Bitwise Not - Or, //(|) Bitwise Or - Xor, //(^) Bitwise Xor - Shl, //(<<) Shift left - Shr, //(<<) Shift right - Cast, //convert type - Ass, //assignement - Trunc, //truncate - - //control flow - Jne, //jump on not equal - Jeq, //jump on equal - Jmp, //unconditional jump - Phi, - //memory - Load(u32), - Store(u32), - //Functions - Intrinsic(OPCODE), //Custom implementation of usefull primitives which are more performant with Aztec backend - Call(noirc_frontend::node_interner::FuncId), //Call a function - Ret, //Return value from a function block - Res, //get result from a function call - - Constrain(ConstrainOp), //write gates enforcing the ContrainOp to be true - - Nop, // no op -} +impl Operation { + pub fn binary(op: BinaryOp, lhs: NodeId, rhs: NodeId) -> Self { + Operation::Binary(Binary::new(op, lhs, rhs)) + } -pub fn is_commutative(op_code: Operation) -> bool { - matches!( - op_code, - Operation::Add - | Operation::SafeAdd - | Operation::Mul - | Operation::SafeMul - | Operation::And - | Operation::Or - | Operation::Xor - | Operation::Constrain(ConstrainOp::Eq) - | Operation::Constrain(ConstrainOp::Neq) - ) -} - -pub fn is_binary(op_code: Operation) -> bool { - matches!( - op_code, - Operation::Add - | Operation::SafeAdd - | Operation::Sub - | Operation::SafeSub - | Operation::Mul - | Operation::SafeMul - | Operation::Udiv - | Operation::Sdiv - | Operation::Urem - | Operation::Srem - | Operation::Div - | Operation::Eq - | Operation::Ne - | Operation::Ugt - | Operation::Uge - | Operation::Ult - | Operation::Ule - | Operation::Sgt - | Operation::Sge - | Operation::Slt - | Operation::Sle - | Operation::Lt - | Operation::Gt - | Operation::Lte - | Operation::Gte - | Operation::And - | Operation::Or - | Operation::Xor - | Operation::Trunc - | Operation::Constrain(_) - ) - - //For the record: Operation::not | Operation::cast => false | Operation::ass | Operation::trunc | Operation::nop - // | Operation::jne | Operation::jeq | Operation::jmp | Operation::phi | Operation::load | Operation::store - // | Operation::Call(_) | Operation::Res | Operation::Ret -} - -pub fn to_operation(op_kind: HirBinaryOpKind, op_type: ObjectType) -> Operation { - match op_kind { - HirBinaryOpKind::Add => Operation::Add, - HirBinaryOpKind::Subtract => Operation::Sub, - HirBinaryOpKind::Multiply => Operation::Mul, - HirBinaryOpKind::Equal => Operation::Eq, - HirBinaryOpKind::NotEqual => Operation::Ne, - HirBinaryOpKind::And => Operation::And, - HirBinaryOpKind::Or => Operation::Or, - HirBinaryOpKind::Xor => Operation::Xor, - HirBinaryOpKind::Shl => Operation::Shl, - HirBinaryOpKind::Shr => Operation::Shr, - HirBinaryOpKind::Divide => { - let num_type: NumericType = op_type.into(); - match num_type { - NumericType::Signed(_) => Operation::Sdiv, - NumericType::Unsigned(_) => Operation::Udiv, - NumericType::NativeField => Operation::Div, + pub fn map_id(&self, mut f: impl FnMut(NodeId) -> NodeId) -> Operation { + use Operation::*; + match self { + Binary(self::Binary { lhs, rhs, operator }) => { + Binary(self::Binary { lhs: f(*lhs), rhs: f(*rhs), operator: operator.clone() }) } - } - HirBinaryOpKind::Less => { - let num_type: NumericType = op_type.into(); - match num_type { - NumericType::Signed(_) => Operation::Slt, - NumericType::Unsigned(_) => Operation::Ult, - NumericType::NativeField => Operation::Lt, + Cast(value) => Cast(f(*value)), + Truncate { value, bit_size, max_bit_size } => { + Truncate { value: f(*value), bit_size: *bit_size, max_bit_size: *max_bit_size } } - } - HirBinaryOpKind::Greater => { - let num_type: NumericType = op_type.into(); - match num_type { - NumericType::Signed(_) => Operation::Sgt, - NumericType::Unsigned(_) => Operation::Ugt, - NumericType::NativeField => Operation::Gt, + Not(id) => Not(f(*id)), + Jne(id, block) => Jne(f(*id), *block), + Jeq(id, block) => Jeq(f(*id), *block), + Jmp(block) => Jmp(*block), + Phi { root, block_args } => Phi { + root: f(*root), + block_args: vecmap(block_args, |(id, block)| (f(*id), *block)), + }, + Load { array_id: array, index } => Load { array_id: *array, index: f(*index) }, + Store { array_id: array, index, value } => { + Store { array_id: *array, index: f(*index), value: f(*value) } + } + Intrinsic(i, args) => Intrinsic(*i, vecmap(args.iter().copied(), f)), + Nop => Nop, + Call(func_id, args) => Call(*func_id, vecmap(args.iter().copied(), f)), + Return(values) => Return(vecmap(values.iter().copied(), f)), + Result { call_instruction, index } => { + Result { call_instruction: f(*call_instruction), index: *index } } } - HirBinaryOpKind::LessEqual => { - let num_type: NumericType = op_type.into(); - match num_type { - NumericType::Signed(_) => Operation::Sle, - NumericType::Unsigned(_) => Operation::Ult, - NumericType::NativeField => Operation::Lte, + } + + /// Mutate each contained NodeId in place using the given function f + pub fn map_id_mut(&mut self, mut f: impl FnMut(NodeId) -> NodeId) { + use Operation::*; + match self { + Binary(self::Binary { lhs, rhs, .. }) => { + *lhs = f(*lhs); + *rhs = f(*rhs); + } + Cast(value) => *value = f(*value), + Truncate { value, .. } => *value = f(*value), + Not(id) => *id = f(*id), + Jne(id, _) => *id = f(*id), + Jeq(id, _) => *id = f(*id), + Jmp(_) => (), + Phi { root, block_args } => { + f(*root); + for (id, _block) in block_args { + *id = f(*id); + } + } + Load { index, .. } => *index = f(*index), + Store { index, value, .. } => { + *index = f(*index); + *value = f(*value); + } + Intrinsic(_, args) => { + for arg in args { + *arg = f(*arg); + } + } + Nop => (), + Call(_, args) => { + for arg in args { + *arg = f(*arg); + } + } + Return(values) => { + for value in values { + *value = f(*value); + } + } + Result { call_instruction, index: _ } => { + *call_instruction = f(*call_instruction); } } - HirBinaryOpKind::GreaterEqual => { - let num_type: NumericType = op_type.into(); - match num_type { - NumericType::Signed(_) => Operation::Sge, - NumericType::Unsigned(_) => Operation::Uge, - NumericType::NativeField => Operation::Gte, + } + + /// This is the same as map_id but doesn't return a new Operation + pub fn for_each_id(&self, mut f: impl FnMut(NodeId)) { + use Operation::*; + match self { + Binary(self::Binary { lhs, rhs, .. }) => { + f(*lhs); + f(*rhs); + } + Cast(value) => f(*value), + Truncate { value, .. } => f(*value), + Not(id) => f(*id), + Jne(id, _) => f(*id), + Jeq(id, _) => f(*id), + Jmp(_) => (), + Phi { root, block_args } => { + f(*root); + for (id, _block) in block_args { + f(*id); + } + } + Load { index, .. } => f(*index), + Store { index, value, .. } => { + f(*index); + f(*value); } + Intrinsic(_, args) => args.iter().copied().for_each(f), + Nop => (), + Call(_, args) => args.iter().copied().for_each(f), + Return(values) => values.iter().copied().for_each(f), + Result { call_instruction, .. } => { + f(*call_instruction); + } + } + } + + pub fn opcode(&self) -> Opcode { + match self { + Operation::Binary(binary) => binary.opcode(), + Operation::Cast(_) => Opcode::Cast, + Operation::Truncate { .. } => Opcode::Truncate, + Operation::Not(_) => Opcode::Not, + Operation::Jne(_, _) => Opcode::Jne, + Operation::Jeq(_, _) => Opcode::Jeq, + Operation::Jmp(_) => Opcode::Jmp, + Operation::Phi { .. } => Opcode::Phi, + Operation::Call(id, _) => Opcode::Call(*id), + Operation::Return(_) => Opcode::Return, + Operation::Result { .. } => Opcode::Results, + Operation::Load { array_id, .. } => Opcode::Load(*array_id), + Operation::Store { array_id, .. } => Opcode::Store(*array_id), + Operation::Intrinsic(opcode, _) => Opcode::Intrinsic(*opcode), + Operation::Nop => Opcode::Nop, } - HirBinaryOpKind::Assign => Operation::Ass, - HirBinaryOpKind::MemberAccess => todo!(), + } +} + +impl BinaryOp { + fn is_commutative(&self) -> bool { + matches!( + self, + BinaryOp::Add + | BinaryOp::SafeAdd + | BinaryOp::Mul + | BinaryOp::SafeMul + | BinaryOp::And + | BinaryOp::Or + | BinaryOp::Xor + // This isn't a match-all pattern in case more ops are ever added + // that aren't commutative + | BinaryOp::Constrain(ConstrainOp::Eq | ConstrainOp::Neq) + ) } } diff --git a/crates/noirc_evaluator/src/ssa/optim.rs b/crates/noirc_evaluator/src/ssa/optim.rs index b5ec0d8be9d..b5b18cda349 100644 --- a/crates/noirc_evaluator/src/ssa/optim.rs +++ b/crates/noirc_evaluator/src/ssa/optim.rs @@ -1,225 +1,137 @@ +use acvm::FieldElement; + use super::{ acir_gen::InternalVar, block::BlockId, context::SsaContext, - mem, - node::{self, Instruction, Node, NodeEval, NodeId, NodeObj, Operation}, + mem::Memory, + node::{ + self, BinaryOp, ConstrainOp, Instruction, Mark, Node, NodeEval, NodeId, ObjectType, Opcode, + Operation, Variable, + }, +}; +use std::{ + borrow::Cow, + collections::{HashMap, VecDeque}, }; -use acvm::FieldElement; -use std::collections::{HashMap, VecDeque}; - -//returns the NodeObj index of a NodeEval object -//if NodeEval is a constant, it may creates a new NodeObj corresponding to the constant value -pub fn to_index(ctx: &mut SsaContext, obj: NodeEval) -> NodeId { - match obj { - NodeEval::Const(c, t) => ctx.get_or_create_const(c, t), - NodeEval::VarOrInstruction(i) => i, - } -} - -// If NodeEval refers to a constant NodeObj, we return a constant NodeEval -pub fn to_const(ctx: &SsaContext, obj: NodeEval) -> NodeEval { - match obj { - NodeEval::Const(_, _) => obj, - NodeEval::VarOrInstruction(i) => { - if let Some(NodeObj::Const(c)) = ctx.try_get_node(i) { - return NodeEval::Const( - FieldElement::from_be_bytes_reduce(&c.value.to_bytes_be()), - c.get_type(), - ); - } - obj - } - } -} // Performs constant folding, arithmetic simplifications and move to standard form +// Modifies ins.mark with whether the instruction should be deleted, replaced, or neither pub fn simplify(ctx: &mut SsaContext, ins: &mut node::Instruction) { //1. constant folding - let l_eval = to_const(ctx, NodeEval::VarOrInstruction(ins.lhs)); - let r_eval = to_const(ctx, NodeEval::VarOrInstruction(ins.rhs)); - let idx = match ins.evaluate(&l_eval, &r_eval) { + let new_id = match ins.evaluate(ctx) { NodeEval::Const(c, t) => ctx.get_or_create_const(c, t), NodeEval::VarOrInstruction(i) => i, }; - if idx != ins.id { - ins.is_deleted = true; - ins.rhs = idx; - if idx == NodeId::dummy() { - ins.operator = node::Operation::Nop; - } + + if new_id != ins.id { + use Mark::*; + ins.mark = if new_id == NodeId::dummy() { Deleted } else { ReplaceWith(new_id) }; return; } //2. standard form ins.standard_form(); - match ins.operator { - node::Operation::Cast => { - if let Some(lhs_obj) = ctx.try_get_node(ins.lhs) { - if lhs_obj.get_type() == ins.res_type { - ins.is_deleted = true; - ins.rhs = ins.lhs; + match ins.operation { + Operation::Cast(value_id) => { + if let Some(value) = ctx.try_get_node(value_id) { + if value.get_type() == ins.res_type { + ins.mark = Mark::ReplaceWith(value_id); return; } } } - // node::Operation::Gte => { - // //a>=b <=> Not(a match op { - node::ConstrainOp::Eq => { - if let (Some(a), Some(b)) = - (mem::Memory::deref(ctx, ins.lhs), mem::Memory::deref(ctx, ins.rhs)) - { - if a == b { - ins.is_deleted = true; - ins.operator = node::Operation::Nop; - } + Operation::Binary(node::Binary { operator: BinaryOp::Constrain(op), lhs, rhs }) => { + match (op, Memory::deref(ctx, lhs), Memory::deref(ctx, rhs)) { + (ConstrainOp::Eq, Some(lhs), Some(rhs)) if lhs == rhs => { + ins.mark = Mark::Deleted; } - } - node::ConstrainOp::Neq => { - if let (Some(a), Some(b)) = - (mem::Memory::deref(ctx, ins.lhs), mem::Memory::deref(ctx, ins.rhs)) - { - assert!(a != b); + (ConstrainOp::Neq, Some(lhs), Some(rhs)) => { + assert_ne!(lhs, rhs); } + _ => (), } - }, + } _ => (), } //3. left-overs (it requires &mut ctx) - if let NodeEval::Const(r_const, r_type) = r_eval { - match ins.operator { - node::Operation::Udiv => { - //TODO handle other bitsize, not only u32!! - ins.rhs = ctx.get_or_create_const( - FieldElement::from((1_u32 / (r_const.to_u128() as u32)) as i128), - r_type, - ); - ins.operator = node::Operation::Mul - } - node::Operation::Sdiv => { - //TODO handle other bitsize, not only i32!! - ins.rhs = ctx.get_or_create_const( - FieldElement::from((1_i32 / (r_const.to_u128() as i32)) as i128), - r_type, - ); - ins.operator = node::Operation::Mul - } - node::Operation::Div => { - ins.rhs = ctx.get_or_create_const(r_const.inverse(), r_type); - ins.operator = node::Operation::Mul - } - node::Operation::Xor => { - if !r_const.is_zero() { - ins.operator = node::Operation::Not; - return; + if let Operation::Binary(binary) = &mut ins.operation { + if let NodeEval::Const(r_const, r_type) = NodeEval::from_id(ctx, binary.rhs) { + match &binary.operator { + BinaryOp::Div => { + binary.rhs = ctx.get_or_create_const(r_const.inverse(), r_type); + binary.operator = BinaryOp::Mul; } - } - node::Operation::Shl => { - ins.operator = node::Operation::Mul; - //todo checks that 2^rhs does not overflow - ins.rhs = ctx.get_or_create_const(FieldElement::from(2_i128).pow(&r_const), r_type); - return; - } - node::Operation::Shr => { - if !matches!(ins.res_type, node::ObjectType::Unsigned(_)) { - todo!("Right shift is only implemented for unsigned integers"); + BinaryOp::Shl => { + binary.operator = BinaryOp::Mul; + //todo checks that 2^rhs does not overflow + binary.rhs = + ctx.get_or_create_const(FieldElement::from(2_i128).pow(&r_const), r_type); + } + BinaryOp::Shr => { + if !matches!(ins.res_type, node::ObjectType::Unsigned(_)) { + todo!("Right shift is only implemented for unsigned integers"); + } + binary.operator = BinaryOp::Udiv; + //todo checks that 2^rhs does not overflow + binary.rhs = + ctx.get_or_create_const(FieldElement::from(2_i128).pow(&r_const), r_type); } - ins.operator = node::Operation::Udiv; - //todo checks that 2^rhs does not overflow - ins.rhs = ctx.get_or_create_const(FieldElement::from(2_i128).pow(&r_const), r_type); - return; + _ => (), } - _ => (), } } - if let NodeEval::Const(l_const, _) = l_eval { - if !l_const.is_zero() && ins.operator == node::Operation::Xor { - ins.operator = node::Operation::Not; - ins.lhs = ins.rhs; - } - if let NodeEval::Const(r_const, _) = r_eval { - if let Operation::Intrinsic(op) = ins.operator { - ins.rhs = evaluate_intrinsic(ctx, op, l_const, r_const); - ins.is_deleted = true; - } + + if let Operation::Intrinsic(opcode, args) = &ins.operation { + let args = args + .iter() + .map(|arg| NodeEval::from_id(ctx, *arg).into_const_value().map(|f| f.to_u128())); + + if let Some(args) = args.collect() { + ins.mark = Mark::ReplaceWith(evaluate_intrinsic(ctx, *opcode, args)); } } } -fn evaluate_intrinsic( - irgen: &mut SsaContext, - op: acvm::acir::OPCODE, - lhs: FieldElement, - rhs: FieldElement, -) -> NodeId { +fn evaluate_intrinsic(ctx: &mut SsaContext, op: acvm::acir::OPCODE, args: Vec) -> NodeId { match op { acvm::acir::OPCODE::ToBits => { - let lhs_int = lhs.to_u128(); - let rhs_int = rhs.to_u128() as u32; - let a = - irgen.mem.create_new_array(rhs_int, node::ObjectType::Unsigned(1), &String::new()); - let pointer = node::Variable { + let bit_count = args[1] as u32; + let array_id = ctx.mem.create_new_array(bit_count, ObjectType::Unsigned(1), ""); + let pointer = Variable { id: NodeId::dummy(), - obj_type: node::ObjectType::Pointer(a), + obj_type: ObjectType::Pointer(array_id), root: None, name: String::new(), def: None, witness: None, - parent_block: irgen.current_block, + parent_block: ctx.current_block, }; - for i in 0..rhs_int { - if lhs_int & (1 << i) != 0 { - irgen.mem.arrays[a as usize] - .values - .push(InternalVar::from(FieldElement::one())); + + for i in 0..bit_count { + if args[0] & (1 << i) != 0 { + ctx.mem[array_id].values.push(InternalVar::from(FieldElement::one())); } else { - irgen.mem.arrays[a as usize] - .values - .push(InternalVar::from(FieldElement::zero())); + ctx.mem[array_id].values.push(InternalVar::from(FieldElement::zero())); } } - irgen.add_variable(pointer, None) + + ctx.add_variable(pointer, None) } _ => todo!(), } } - ////////////////////CSE//////////////////////////////////////// pub fn find_similar_instruction( igen: &SsaContext, - lhs: NodeId, - rhs: NodeId, - prev_ins: &VecDeque, -) -> Option { - for iter in prev_ins { - if let Some(ins) = igen.try_get_instruction(*iter) { - if ins.lhs == lhs && ins.rhs == rhs { - return Some(*iter); - } - } - } - None -} - -pub fn find_similar_instruction_with_multiple_arguments( - igen: &SsaContext, - lhs: NodeId, - rhs: NodeId, - ins_args: &[NodeId], + operation: &Operation, prev_ins: &VecDeque, ) -> Option { for iter in prev_ins { if let Some(ins) = igen.try_get_instruction(*iter) { - if ins.lhs == lhs && ins.rhs == rhs && ins.ins_arguments == ins_args { + if &ins.operation == operation { return Some(*iter); } } @@ -229,53 +141,47 @@ pub fn find_similar_instruction_with_multiple_arguments( pub fn find_similar_cast( igen: &SsaContext, - lhs: NodeId, + operator: &Operation, res_type: node::ObjectType, prev_ins: &VecDeque, ) -> Option { for iter in prev_ins { if let Some(ins) = igen.try_get_instruction(*iter) { - if ins.lhs == lhs && ins.res_type == res_type { + if &ins.operation == operator && ins.res_type == res_type { return Some(*iter); } } } None } + pub enum CseAction { - Replace, //replace the instruction - Remove, //remove the instruction - Keep, //keep the instruction + ReplaceWith(NodeId), + Remove(NodeId), + Keep, } -//Returns an id and an action: -//- replace => the instruction should be replaced by the returned id -//- remove => the instruction corresponding to the returned id should be removed -//- keep => keep the instruction -pub fn find_similar_mem_instruction( +fn find_similar_mem_instruction( ctx: &SsaContext, - op: node::Operation, - ins_id: NodeId, - lhs: NodeId, - rhs: NodeId, - anchor: &HashMap>, -) -> (NodeId, CseAction) { + op: &Operation, + anchor: &mut Anchor, +) -> CseAction { match op { - node::Operation::Load(_) => { - for iter in anchor[&op].iter().rev() { + Operation::Load { array_id, index } => { + for iter in anchor.get_all(op.opcode()).iter().rev() { if let Some(ins_iter) = ctx.try_get_instruction(*iter) { - match ins_iter.operator { - node::Operation::Load(_) => { - if ins_iter.lhs == lhs { - return (*iter, CseAction::Replace); - } + match &ins_iter.operation { + Operation::Load { array_id: array_id2, index: _ } => { + assert_eq!(array_id, array_id2); + return CseAction::ReplaceWith(*iter); } - node::Operation::Store(_) => { - if ins_iter.rhs == lhs { - return (ins_iter.lhs, CseAction::Replace); + Operation::Store { array_id: array_id2, index: index2, value } => { + assert_eq!(array_id, array_id2); + if index == index2 { + return CseAction::ReplaceWith(*value); } else { //TODO: If we know that ins.lhs value cannot be equal to ins_iter.rhs, we could continue instead - return (ins_id, CseAction::Keep); + return CseAction::Keep; } } _ => unreachable!("invalid operator in the memory anchor list"), @@ -283,21 +189,23 @@ pub fn find_similar_mem_instruction( } } } - node::Operation::Store(x) => { - let prev_ins = &anchor[&node::Operation::Load(x)]; - for iter in prev_ins.iter().rev() { - if let Some(ins_iter) = ctx.try_get_instruction(*iter) { - match ins_iter.operator { - node::Operation::Load(_) => { + Operation::Store { array_id, index, value: _ } => { + let opcode = Opcode::Load(*array_id); + for node_id in anchor.get_all(opcode).iter().rev() { + if let Some(ins_iter) = ctx.try_get_instruction(*node_id) { + match ins_iter.operation { + Operation::Load { array_id: array_id2, .. } => { + assert_eq!(*array_id, array_id2); //TODO: If we know that ins.rhs value cannot be equal to ins_iter.rhs, we could continue instead - return (ins_id, CseAction::Keep); + return CseAction::Keep; } - node::Operation::Store(_) => { - if ins_iter.rhs == rhs { - return (*iter, CseAction::Remove); + Operation::Store { index: index2, array_id: array_id2, .. } => { + assert_eq!(*array_id, array_id2); + if *index == index2 { + return CseAction::Remove(*node_id); } else { //TODO: If we know that ins.rhs value cannot be equal to ins_iter.rhs, we could continue instead - return (ins_id, CseAction::Keep); + return CseAction::Keep; } } _ => unreachable!("invalid operator in the memory anchor list"), @@ -307,33 +215,57 @@ pub fn find_similar_mem_instruction( } _ => unreachable!("invalid non memory operator"), } - (ins_id, CseAction::Keep) + + CseAction::Keep } pub fn propagate(ctx: &SsaContext, id: NodeId) -> NodeId { - let mut result = id; if let Some(obj) = ctx.try_get_instruction(id) { - if obj.operator == node::Operation::Ass || obj.is_deleted { - result = obj.rhs; + if let Mark::ReplaceWith(replacement) = obj.mark { + return replacement; + } else if let Operation::Binary(node::Binary { operator: BinaryOp::Assign, rhs, .. }) = + &obj.operation + { + return *rhs; } } - result + id } //common subexpression elimination, starting from the root pub fn cse(igen: &mut SsaContext, first_block: BlockId) -> Option { - let mut anchor = HashMap::new(); + let mut anchor = Anchor::default(); cse_tree(igen, first_block, &mut anchor) } +/// A list of instructions with the same Operation variant, ordered by the order +/// they appear in their respective blocks. +#[derive(Default, Clone)] +struct Anchor { + map: HashMap>, +} + +impl Anchor { + fn push_front(&mut self, op: &Operation, id: NodeId) { + let key = match op { + Operation::Store { array_id, .. } => Opcode::Load(*array_id), + _ => op.opcode(), + }; + self.map.entry(key).or_insert_with(VecDeque::new).push_front(id); + } + + fn get_all(&self, opcode: Opcode) -> Cow> { + match self.map.get(&opcode) { + Some(vec) => Cow::Borrowed(vec), + None => Cow::Owned(VecDeque::new()), + } + } +} + //Perform CSE for the provided block and then process its children following the dominator tree, passing around the anchor list. -pub fn cse_tree( - igen: &mut SsaContext, - block_id: BlockId, - anchor: &mut HashMap>, -) -> Option { +fn cse_tree(igen: &mut SsaContext, block_id: BlockId, anchor: &mut Anchor) -> Option { let mut instructions = Vec::new(); - let mut res = block_cse(igen, block_id, anchor, &mut instructions); + let mut res = cse_block_with_anchor(igen, block_id, &mut instructions, anchor); for b in igen[block_id].dominated.clone() { let sub_res = cse_tree(igen, b, &mut anchor.clone()); if sub_res.is_some() { @@ -343,21 +275,20 @@ pub fn cse_tree( res } -pub fn anchor_push(op: node::Operation, anchor: &mut HashMap>) { - match op { - node::Operation::Store(x) => { - anchor.entry(node::Operation::Load(x)).or_insert_with(VecDeque::new) - } - _ => anchor.entry(op).or_insert_with(VecDeque::new), - }; +pub fn cse_block( + ctx: &mut SsaContext, + block_id: BlockId, + instructions: &mut Vec, +) -> Option { + cse_block_with_anchor(ctx, block_id, instructions, &mut Anchor::default()) } //Performs common subexpression elimination and copy propagation on a block -pub fn block_cse( +fn cse_block_with_anchor( ctx: &mut SsaContext, block_id: BlockId, - anchor: &mut HashMap>, instructions: &mut Vec, + anchor: &mut Anchor, ) -> Option { let mut new_list = Vec::new(); let bb = &ctx[block_id]; @@ -366,194 +297,101 @@ pub fn block_cse( instructions.append(&mut bb.instructions.clone()); } - for iter in instructions { - if let Some(ins) = ctx.try_get_instruction(*iter) { - let mut to_delete = false; - let mut i_lhs = ins.lhs; - let mut i_rhs = ins.rhs; - let mut phi_args = Vec::new(); - let mut ins_args = Vec::new(); - let mut to_update_phi = false; - let mut to_update = false; - - if ins.is_deleted { + for ins_id in instructions { + if let Some(ins) = ctx.try_get_instruction(*ins_id) { + if ins.is_deleted() { continue; } - anchor_push(ins.operator, anchor); - if node::is_binary(ins.operator) { - //binary operation: - i_lhs = propagate(ctx, ins.lhs); - i_rhs = propagate(ctx, ins.rhs); - if let Some(j) = find_similar_instruction(ctx, i_lhs, i_rhs, &anchor[&ins.operator]) - { - to_delete = true; //we want to delete ins but ins is immutable so we use the new_list instead - i_rhs = j; - } else { - new_list.push(*iter); - anchor.get_mut(&ins.operator).unwrap().push_front(*iter); - } - } else { - match ins.operator { - node::Operation::Load(_) | node::Operation::Store(_) => { - i_lhs = propagate(ctx, ins.lhs); - i_rhs = propagate(ctx, ins.rhs); - let (cse_id, cse_action) = find_similar_mem_instruction( - ctx, - ins.operator, - ins.id, - i_lhs, - i_rhs, - anchor, - ); - match cse_action { - CseAction::Keep => new_list.push(*iter), - CseAction::Replace => { - to_delete = true; - i_rhs = cse_id; - } - CseAction::Remove => { - new_list.push(*iter); - // TODO if not found, it should be removed from other blocks; we could keep a list of instructions to remove - if let Some(pos) = new_list.iter().position(|x| *x == cse_id) { - new_list.remove(pos); - } - } - } - } - node::Operation::Ass => { - //assignement - i_rhs = propagate(ctx, ins.rhs); - to_delete = true; - } - node::Operation::Phi => { - // propagate phi arguments - for a in &ins.phi_arguments { - phi_args.push((propagate(ctx, a.0), a.1)); - if phi_args.last().unwrap().0 != a.0 { - to_update_phi = true; - } - } - if let Some(first) = node::Instruction::simplify_phi(ins.id, &phi_args) { - if first == ins.id { - new_list.push(*iter); - } else { - to_delete = true; - i_rhs = first; - to_update_phi = false; - } - } else { - to_delete = true; - } + + let operator = ins.operation.map_id(|id| propagate(ctx, id)); + + let mut new_mark = Mark::None; + + match &operator { + Operation::Binary(binary) => { + let variants = anchor.get_all(binary.opcode()); + if let Some(similar) = find_similar_instruction(ctx, &operator, &variants) { + new_mark = Mark::ReplaceWith(similar); + } else if binary.operator == BinaryOp::Assign { + new_mark = Mark::ReplaceWith(binary.rhs); + } else { + new_list.push(*ins_id); + anchor.push_front(&ins.operation, *ins_id); } - node::Operation::Cast => { - //Propagate cast argument - i_lhs = propagate(ctx, ins.lhs); - i_rhs = i_lhs; - //Similar cast must have same type - if let Some(j) = - find_similar_cast(ctx, i_lhs, ins.res_type, &anchor[&ins.operator]) - { - to_delete = true; //we want to delete ins but ins is immutable so we use the new_list instead - i_rhs = j; - } else { - new_list.push(*iter); - anchor.get_mut(&ins.operator).unwrap().push_front(*iter); + } + Operation::Load { .. } | Operation::Store { .. } => { + match find_similar_mem_instruction(ctx, &operator, anchor) { + CseAction::Keep => new_list.push(*ins_id), + CseAction::ReplaceWith(new_id) => { + new_mark = Mark::ReplaceWith(new_id); } - } - node::Operation::Call(_) | node::Operation::Ret => { - //No CSE for function calls because of possible side effect - TODO checks if a function has side effect when parsed and do cse for these. - //Propagate arguments: - for a in &ins.ins_arguments { - let new_a = propagate(ctx, *a); - if !to_update && new_a != *a { - to_update = true; + CseAction::Remove(id_to_remove) => { + new_list.push(*ins_id); + // TODO if not found, it should be removed from other blocks; we could keep a list of instructions to remove + if let Some(id) = new_list.iter().position(|x| *x == id_to_remove) { + new_list.remove(id); } - ins_args.push(new_a); } - new_list.push(*iter); } - node::Operation::Intrinsic(_) => { - //n.b this could be the default behovoir for binary operations - for a in &ins.ins_arguments { - let new_a = propagate(ctx, *a); - if !to_update && new_a != *a { - to_update = true; - } - ins_args.push(new_a); - } - i_lhs = propagate(ctx, ins.lhs); - i_rhs = propagate(ctx, ins.rhs); - if let Some(j) = find_similar_instruction_with_multiple_arguments( - ctx, - i_lhs, - i_rhs, - &ins_args, - &anchor[&ins.operator], - ) { - to_delete = true; //we want to delete ins but ins is immutable so we use the new_list instead - i_rhs = j; + } + Operation::Phi { block_args, .. } => { + // propagate phi arguments + if let Some(first) = Instruction::simplify_phi(ins.id, block_args) { + if first == ins.id { + new_list.push(*ins_id); } else { - new_list.push(*iter); - anchor.get_mut(&ins.operator).unwrap().push_front(*iter); + new_mark = Mark::ReplaceWith(first); } + } else { + new_mark = Mark::Deleted; } - _ => { - //TODO: checks we do not need to propagate res arguments - new_list.push(*iter); + } + Operation::Cast(_) => { + //Similar cast must have same type + if let Some(similar) = find_similar_cast( + ctx, + &operator, + ins.res_type, + &anchor.get_all(Opcode::Cast), + ) { + new_mark = Mark::ReplaceWith(similar); + } else { + new_list.push(*ins_id); + anchor.push_front(&operator, *ins_id); } } - } - - if to_update_phi { - let update = ctx.get_mut_instruction(*iter); - update.phi_arguments = phi_args; - } else if to_delete || ins.lhs != i_lhs || ins.rhs != i_rhs || to_update { - //update i: - let ii_l = ins.lhs; - let ii_r = ins.rhs; - let update = ctx.get_mut_instruction(*iter); - update.lhs = i_lhs; - update.rhs = i_rhs; - update.is_deleted = to_delete; - if to_update { - update.ins_arguments = ins_args; + Operation::Call(..) | Operation::Return(..) => { + //No CSE for function calls because of possible side effect - TODO checks if a function has side effect when parsed and do cse for these. + //Propagate arguments: + new_list.push(*ins_id); } - //update instruction name - for debug/pretty print purposes only ///////////////////// - if let Some(Instruction { operator: Operation::Ass, lhs, .. }) = - ctx.try_get_instruction(ii_l) - { - if let Ok(lv) = ctx.get_variable(*lhs) { - let i_name = lv.name.clone(); - if let Some(p_ins) = ctx.try_get_mut_instruction(i_lhs) { - if p_ins.res_name.is_empty() { - p_ins.res_name = i_name; - } - } + Operation::Intrinsic(..) => { + //n.b this could be the default behavior for binary operations + if let Some(similar) = + find_similar_instruction(ctx, &operator, &anchor.get_all(operator.opcode())) + { + new_mark = Mark::ReplaceWith(similar); + } else { + new_list.push(*ins_id); + anchor.push_front(&operator, *ins_id); } } - if let Some(Instruction { operator: Operation::Ass, lhs, .. }) = - ctx.try_get_instruction(ii_r) - { - if let Ok(lv) = ctx.get_variable(*lhs) { - let i_name = lv.name.clone(); - if let Some(p_ins) = ctx.try_get_mut_instruction(i_rhs) { - if p_ins.res_name.is_empty() { - p_ins.res_name = i_name; - } - } - } + _ => { + //TODO: checks we do not need to propagate res arguments + new_list.push(*ins_id); } - ////////////////////////////////////////update instruction name for debug purposes//////////////////////////////// + } + + let update = ctx.get_mut_instruction(*ins_id); + update.operation = operator; + update.mark = new_mark; + if new_mark == Mark::Deleted { + update.operation = Operation::Nop; } } } - let mut last = None; - for i in new_list.iter().rev() { - if is_some(ctx, *i) { - last = Some(*i); - break; - } - } + + let last = new_list.iter().copied().rev().find(|id| is_some(ctx, *id)); ctx[block_id].instructions = new_list; last } @@ -563,7 +401,7 @@ pub fn is_some(ctx: &SsaContext, id: NodeId) -> bool { return false; } if let Some(ins) = ctx.try_get_instruction(id) { - if ins.operator != node::Operation::Nop { + if ins.operation != Operation::Nop { return true; } } else if ctx.try_get_node(id).is_some() { diff --git a/crates/noirc_evaluator/src/ssa/ssa_form.rs b/crates/noirc_evaluator/src/ssa/ssa_form.rs index 8d13031610b..60f1762087f 100644 --- a/crates/noirc_evaluator/src/ssa/ssa_form.rs +++ b/crates/noirc_evaluator/src/ssa/ssa_form.rs @@ -1,3 +1,4 @@ +use crate::ssa::node::{Mark, Operation}; use noirc_frontend::{node_interner::DefinitionId, ArraySize}; use super::{ @@ -16,21 +17,24 @@ pub fn write_phi(ctx: &mut SsaContext, predecessors: &[BlockId], var: NodeId, ph } let s2 = node::Instruction::simplify_phi(phi, &result); if let Some(phi_ins) = ctx.try_get_mut_instruction(phi) { - assert!(phi_ins.phi_arguments.is_empty()); + let phi_args = match &mut phi_ins.operation { + Operation::Phi { block_args, .. } => block_args, + _ => unreachable!(), + }; + + assert_eq!(phi_args.len(), 0); if let Some(s_phi) = s2 { if s_phi != phi { - //s2 != phi - phi_ins.is_deleted = true; - phi_ins.rhs = s_phi; + phi_ins.mark = Mark::ReplaceWith(s_phi); //eventually simplify recursively: if a phi instruction is in phi use list, call simplify_phi() on it //but cse should deal with most of it. } else { //s2 == phi - phi_ins.phi_arguments = result; + *phi_args = result; } } else { //s2 is None - phi_ins.is_deleted = true; + phi_ins.mark = Mark::Deleted; } } } @@ -41,9 +45,9 @@ pub fn seal_block(ctx: &mut SsaContext, block_id: BlockId) { let instructions = block.instructions.clone(); for i in instructions { if let Some(ins) = ctx.try_get_instruction(i) { - let rhs = ins.rhs; - if ins.operator == node::Operation::Phi { - write_phi(ctx, &pred, rhs, i); + if let Operation::Phi { root, .. } = &ins.operation { + let root = *root; + write_phi(ctx, &pred, root, i); } } } diff --git a/crates/noirc_frontend/src/hir_def/expr.rs b/crates/noirc_frontend/src/hir_def/expr.rs index 098a8e55d35..c9ecba823a5 100644 --- a/crates/noirc_frontend/src/hir_def/expr.rs +++ b/crates/noirc_frontend/src/hir_def/expr.rs @@ -66,7 +66,6 @@ pub enum HirBinaryOpKind { Xor, Shl, Shr, - MemberAccess, Assign, } diff --git a/crates/noirc_frontend/src/parser/parser.rs b/crates/noirc_frontend/src/parser/parser.rs index b9f010c9789..816b82fb94d 100644 --- a/crates/noirc_frontend/src/parser/parser.rs +++ b/crates/noirc_frontend/src/parser/parser.rs @@ -1013,7 +1013,6 @@ mod test { #[test] fn parse_parenthesized_expression() { parse_all(atom(expression()), vec!["(0)", "(x+a)", "({(({{({(nested)})}}))})"]); - parse_all_failing(atom(expression()), vec!["(x+a", "((x+a)", "(,)"]); } diff --git a/examples/assign_ex/proofs/p.proof b/examples/assign_ex/proofs/p.proof deleted file mode 100644 index 97a5c2d5fe5..00000000000 --- a/examples/assign_ex/proofs/p.proof +++ /dev/null @@ -1 +0,0 @@ -0b607923732f6ef41420f467ba317c3dc7ef5ae0d2fa6f7e7603616b525fd0d724fcacf980cceb23b4a80a7acdffe54f85c40a6d0ee0afce939cada3d56d4a821bcf1cc1383220ef7e58d9883f650382fb05af91f5a854dbd9882e1a0db23acb03011c9b9ea5b107f6e54b649efa24eba5246082c7f2e8d43d23545f50625eef08b8fa42e293685ca510d70909fabcc729cc0300171007df2f9c87072c12e4ac12ead69d69eb881f23eaa5a71e1a12da04aec4dc0e67658b92c9091cf367fcf72b3ce1d71c303626e540e3a9dd1c7edcfdf019817268c09bad684657442cbaa51167a294bb6f4c314b7dc45caf72102e15f4f034be72027cc6107f0daa7d77ac1046395ad5c5d9490b61843d8aac7e67eb38dce71658260abacdebd792cccbac19c8ac06714cbe449e6bcd943e451da21d6545869a0f0736505e29d15cc05d6c1d02d585a0299cfea34d1a3bd809a485f8da610d4f881068828ad69882356578021e7c575f135355b80f6d1f946aeae21f2860713552e7b6a978b45b9920b1ee11686f1e65420fa45939e01f073cb7bdb09b196cc50c2fbb931d852b8579cb5901164110e74e5506daed8cc3709830b54259b797d8d09ccf6cf8d3e290ee35b82bd5d5dd7fc61fccb57d0f3f55f03675b95b7c2c799f58474df05827c4fe7a5b235f822a0d07888521e0a6c53be5b615bf6d3f32ed245479f4f9245a767f196d28d1918e95ae9b93044526454403cc03c1972a75d466932aa827326ce6f3b5a92347dddfad2a030f403d21608b096e662f27b16dc657c8707558517500e5affd0b0d05cad0fc75e66b4a5ee8626fac1fa5d569ad5f3aa552880375e8e97796f90a0b111efdb8ebc879640979404fe80adf40e9eb9ed1be006ace395207f15bad1277879627028bf68f02e50821b8b3f737fef96ddb018d690562f1b66f214be421c57d0aa023736db5c8f541fcf2f6b9b29a499d8ebb917c08391491048fe1d30053f864fb61386d38423585600c91d872b65c3d74d279a0bf5e942a78c08eb228686c6e3b1aa131d7cbfd429234d81e7142e39c646b291b66054522d7503e111f14d85d06813bf41dd445191af812c5e2de6033d881e52c7fb2e61b69ea1100256f8d9989ebe3959fca7b537b737619ddaf47fb952918cafb9ae196491f91d506ff5f41d91968ce3e5215c228c018d029326150ffa7a6e68ea4b2c4a513cebe1e97a9883be23fcfc376243033856ae90dbcfe1a9ba31fa353b83dd54144b5601efab43889cafff840755053e20106614154f7a997d3d20e01a47e1eb37a876c2ef95b58d6ae388b538352c25cbc2d1cd4387020bfc1d8ad96e50a1777aa4a29126eaa36a6b6b31f66eed31433d747b0513bd7761c13f6396f24e63b3acb74fe2a175323658b0201a88b21f6c640fb81a9b1a325a980cb8923392a69bb2e56fd23c6bec708c8c05d5b4fbff9448b47f922ff37208985a3b10c9afea8d19c01b9250110b68ecadfebb602964776f96b44ef3013a41bca5c4569322d88609ae87623da6241b5d3e44bc7c073ce5bd8e7f2c363c64f05eb2a898f1f2fdbda2ac8a0216e4b08a3e25c7eaeeabfaea90976a20772b90b0e4554dc2d1940fcf10c2eec0575686bd7e82f055571d27da2d6732849415331f99ca975a73674aba569808b0bd9fbbae1c6a04a0e82846cd19ad31ecd8d3d9b23a113bca03752f1e7b426c8 \ No newline at end of file