diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 0d7b9d10f4c..e34e0070d4b 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -315,7 +315,12 @@ pub(crate) enum Instruction { /// else_value /// } /// ``` - IfElse { then_condition: ValueId, then_value: ValueId, else_value: ValueId }, + IfElse { + then_condition: ValueId, + then_value: ValueId, + else_condition: ValueId, + else_value: ValueId, + }, /// Creates a new array or slice. /// @@ -632,11 +637,14 @@ impl Instruction { assert_message: assert_message.clone(), } } - Instruction::IfElse { then_condition, then_value, else_value } => Instruction::IfElse { - then_condition: f(*then_condition), - then_value: f(*then_value), - else_value: f(*else_value), - }, + Instruction::IfElse { then_condition, then_value, else_condition, else_value } => { + Instruction::IfElse { + then_condition: f(*then_condition), + then_value: f(*then_value), + else_condition: f(*else_condition), + else_value: f(*else_value), + } + } Instruction::MakeArray { elements, typ } => Instruction::MakeArray { elements: elements.iter().copied().map(f).collect(), typ: typ.clone(), @@ -695,9 +703,10 @@ impl Instruction { | Instruction::RangeCheck { value, .. } => { f(*value); } - Instruction::IfElse { then_condition, then_value, else_value } => { + Instruction::IfElse { then_condition, then_value, else_condition, else_value } => { f(*then_condition); f(*then_value); + f(*else_condition); f(*else_value); } Instruction::MakeArray { elements, typ: _ } => { @@ -860,7 +869,7 @@ impl Instruction { None } } - Instruction::IfElse { then_condition, then_value, else_value } => { + Instruction::IfElse { then_condition, then_value, else_condition, else_value } => { let typ = dfg.type_of_value(*then_value); if let Some(constant) = dfg.get_numeric_constant(*then_condition) { @@ -879,11 +888,13 @@ impl Instruction { if matches!(&typ, Type::Numeric(_)) { let then_condition = *then_condition; + let else_condition = *else_condition; let result = ValueMerger::merge_numeric_values( dfg, block, then_condition, + else_condition, then_value, else_value, ); diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index 46ca7bb8c24..72d781c6e95 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -465,8 +465,12 @@ fn simplify_slice_push_back( let mut value_merger = ValueMerger::new(dfg, block, &mut slice_sizes, unknown, None, call_stack); - let new_slice = - value_merger.merge_values(len_not_equals_capacity, set_last_slice_value, new_slice); + let new_slice = value_merger.merge_values( + len_not_equals_capacity, + len_equals_capacity, + set_last_slice_value, + new_slice, + ); SimplifyResult::SimplifiedToMultiple(vec![new_slice_length, new_slice]) } diff --git a/compiler/noirc_evaluator/src/ssa/ir/printer.rs b/compiler/noirc_evaluator/src/ssa/ir/printer.rs index 6bebd21fe61..5a451301d24 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/printer.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/printer.rs @@ -209,11 +209,15 @@ fn display_instruction_inner( Instruction::RangeCheck { value, max_bit_size, .. } => { writeln!(f, "range_check {} to {} bits", show(*value), *max_bit_size,) } - Instruction::IfElse { then_condition, then_value, else_value } => { + Instruction::IfElse { then_condition, then_value, else_condition, else_value } => { let then_condition = show(*then_condition); let then_value = show(*then_value); + let else_condition = show(*else_condition); let else_value = show(*else_value); - writeln!(f, "if {then_condition} then {then_value} else {else_value}") + writeln!( + f, + "if {then_condition} then {then_value} else (if {else_condition}) {else_value}" + ) } Instruction::MakeArray { elements, typ } => { write!(f, "make_array [")?; diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index c8dd0e3c5a3..3fbccf93ec9 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -537,7 +537,11 @@ impl<'f> Context<'f> { let args = vecmap(then_args.iter().zip(else_args), |(then_arg, else_arg)| { (self.inserter.resolve(*then_arg), self.inserter.resolve(else_arg)) }); - + let else_condition = if let Some(branch) = cond_context.else_branch { + branch.condition + } else { + self.inserter.function.dfg.make_constant(FieldElement::zero(), Type::bool()) + }; let block = self.inserter.function.entry_block(); // Cannot include this in the previous vecmap since it requires exclusive access to self @@ -545,6 +549,7 @@ impl<'f> Context<'f> { let instruction = Instruction::IfElse { then_condition: cond_context.then_branch.condition, then_value: then_arg, + else_condition, else_value: else_arg, }; let call_stack = cond_context.call_stack.clone(); @@ -684,10 +689,13 @@ impl<'f> Context<'f> { ) .first(); + let else_condition = self + .insert_instruction(Instruction::Not(condition), call_stack.clone()); + let instruction = Instruction::IfElse { then_condition: condition, then_value: value, - + else_condition, else_value: previous_value, }; @@ -859,9 +867,11 @@ mod test { v1 = not v0 enable_side_effects u1 1 v3 = cast v0 as Field - v5 = mul v3, Field -1 - v7 = add Field 4, v5 - return v7 + v4 = cast v1 as Field + v6 = mul v3, Field 3 + v8 = mul v4, Field 4 + v9 = add v6, v8 + return v9 } "; @@ -921,12 +931,14 @@ mod test { b0(v0: u1, v1: &mut Field): enable_side_effects v0 v2 = load v1 -> Field - v3 = cast v0 as Field - v5 = sub Field 5, v2 - v6 = mul v3, v5 - v7 = add v2, v6 - store v7 at v1 - v8 = not v0 + v3 = not v0 + v4 = cast v0 as Field + v5 = cast v3 as Field + v7 = mul v4, Field 5 + v8 = mul v5, v2 + v9 = add v7, v8 + store v9 at v1 + v10 = not v0 enable_side_effects u1 1 return } @@ -958,19 +970,22 @@ mod test { b0(v0: u1, v1: &mut Field): enable_side_effects v0 v2 = load v1 -> Field - v3 = cast v0 as Field - v5 = sub Field 5, v2 - v6 = mul v3, v5 - v7 = add v2, v6 - store v7 at v1 - v8 = not v0 - enable_side_effects v8 - v9 = load v1 -> Field - v10 = cast v8 as Field - v12 = sub Field 6, v9 - v13 = mul v10, v12 - v14 = add v9, v13 - store v14 at v1 + v3 = not v0 + v4 = cast v0 as Field + v5 = cast v3 as Field + v7 = mul v4, Field 5 + v8 = mul v5, v2 + v9 = add v7, v8 + store v9 at v1 + v10 = not v0 + enable_side_effects v10 + v11 = load v1 -> Field + v12 = cast v10 as Field + v13 = cast v0 as Field + v15 = mul v12, Field 6 + v16 = mul v13, v11 + v17 = add v15, v16 + store v17 at v1 enable_side_effects u1 1 return } @@ -1014,123 +1029,101 @@ mod test { // b7 b8 // ↘ ↙ // b9 - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let b1 = builder.insert_block(); - let b2 = builder.insert_block(); - let b3 = builder.insert_block(); - let b4 = builder.insert_block(); - let b5 = builder.insert_block(); - let b6 = builder.insert_block(); - let b7 = builder.insert_block(); - let b8 = builder.insert_block(); - let b9 = builder.insert_block(); - - let c1 = builder.add_parameter(Type::bool()); - let c4 = builder.add_parameter(Type::bool()); - - let r1 = builder.insert_allocate(Type::field()); - - let store_value = |builder: &mut FunctionBuilder, value: u128| { - let value = builder.field_constant(value); - builder.insert_store(r1, value); - }; - - let test_function = Id::test_new(1); - - let call_test_function = |builder: &mut FunctionBuilder, block: u128| { - let block = builder.field_constant(block); - let load = builder.insert_load(r1, Type::field()); - builder.insert_call(test_function, vec![block, load], Vec::new()); - }; - - let switch_store_and_test_function = - |builder: &mut FunctionBuilder, block, block_number: u128| { - builder.switch_to_block(block); - store_value(builder, block_number); - call_test_function(builder, block_number); - }; - - let switch_and_test_function = - |builder: &mut FunctionBuilder, block, block_number: u128| { - builder.switch_to_block(block); - call_test_function(builder, block_number); - }; - - store_value(&mut builder, 0); - call_test_function(&mut builder, 0); - builder.terminate_with_jmp(b1, vec![]); - - switch_store_and_test_function(&mut builder, b1, 1); - builder.terminate_with_jmpif(c1, b2, b3); - switch_store_and_test_function(&mut builder, b2, 2); - builder.terminate_with_jmp(b4, vec![]); - - switch_store_and_test_function(&mut builder, b3, 3); - builder.terminate_with_jmp(b8, vec![]); - - switch_and_test_function(&mut builder, b4, 4); - builder.terminate_with_jmpif(c4, b5, b6); - - switch_store_and_test_function(&mut builder, b5, 5); - builder.terminate_with_jmp(b7, vec![]); - - switch_store_and_test_function(&mut builder, b6, 6); - builder.terminate_with_jmp(b7, vec![]); - - switch_and_test_function(&mut builder, b7, 7); - builder.terminate_with_jmp(b9, vec![]); - - switch_and_test_function(&mut builder, b8, 8); - builder.terminate_with_jmp(b9, vec![]); + let src = " + acir(inline) fn main f0 { + b0(v0: u1, v1: u1): + v2 = allocate -> &mut Field + store Field 0 at v2 + v4 = load v2 -> Field + // call v1(Field 0, v4) + jmp b1() + b1(): + store Field 1 at v2 + v6 = load v2 -> Field + // call v1(Field 1, v6) + jmpif v0 then: b2, else: b3 + b2(): + store Field 2 at v2 + v8 = load v2 -> Field + // call v1(Field 2, v8) + jmp b4() + b4(): + v12 = load v2 -> Field + // call v1(Field 4, v12) + jmpif v1 then: b5, else: b6 + b5(): + store Field 5 at v2 + v14 = load v2 -> Field + // call v1(Field 5, v14) + jmp b7() + b7(): + v18 = load v2 -> Field + // call v1(Field 7, v18) + jmp b9() + b9(): + v22 = load v2 -> Field + // call v1(Field 9, v22) + v23 = load v2 -> Field + return v23 + b6(): + store Field 6 at v2 + v16 = load v2 -> Field + // call v1(Field 6, v16) + jmp b7() + b3(): + store Field 3 at v2 + v10 = load v2 -> Field + // call v1(Field 3, v10) + jmp b8() + b8(): + v20 = load v2 -> Field + // call v1(Field 8, v20) + jmp b9() + }"; - switch_and_test_function(&mut builder, b9, 9); - let load = builder.insert_load(r1, Type::field()); - builder.terminate_with_return(vec![load]); + let ssa = Ssa::from_str(src).unwrap(); - let ssa = builder.finish().flatten_cfg().mem2reg(); + let ssa = ssa.flatten_cfg().mem2reg(); - // Expected results after mem2reg removes the allocation and each load and store: - // - // fn main f0 { - // b0(v0: u1, v1: u1): - // call test_function(Field 0, Field 0) - // call test_function(Field 1, Field 1) - // enable_side_effects v0 - // call test_function(Field 2, Field 2) - // call test_function(Field 4, Field 2) - // v29 = and v0, v1 - // enable_side_effects v29 - // call test_function(Field 5, Field 5) - // v32 = not v1 - // v33 = and v0, v32 - // enable_side_effects v33 - // call test_function(Field 6, Field 6) - // enable_side_effects v0 - // v36 = mul v1, Field 5 - // v37 = mul v32, Field 2 - // v38 = add v36, v37 - // v39 = mul v1, Field 5 - // v40 = mul v32, Field 6 - // v41 = add v39, v40 - // call test_function(Field 7, v42) - // v43 = not v0 - // enable_side_effects v43 - // store Field 3 at v2 - // call test_function(Field 3, Field 3) - // call test_function(Field 8, Field 3) - // enable_side_effects Field 1 - // v47 = mul v0, v41 - // v48 = mul v43, Field 1 - // v49 = add v47, v48 - // v50 = mul v0, v44 - // v51 = mul v43, Field 3 - // v52 = add v50, v51 - // call test_function(Field 9, v53) - // return v54 - // } + let expected = " + acir(inline) fn main f0 { + b0(v0: u1, v1: u1): + v2 = allocate -> &mut Field + enable_side_effects v0 + v3 = not v0 + v4 = cast v0 as Field + v5 = cast v3 as Field + v7 = mul v4, Field 2 + v8 = add v7, v5 + v9 = mul v0, v1 + enable_side_effects v9 + v10 = not v9 + v11 = cast v9 as Field + v12 = cast v10 as Field + v14 = mul v11, Field 5 + v15 = mul v12, v8 + v16 = add v14, v15 + v17 = not v1 + v18 = mul v0, v17 + enable_side_effects v18 + v19 = not v18 + v20 = cast v18 as Field + v21 = cast v19 as Field + v23 = mul v20, Field 6 + v24 = mul v21, v16 + v25 = add v23, v24 + enable_side_effects v0 + v26 = not v0 + enable_side_effects v26 + v27 = cast v26 as Field + v28 = cast v0 as Field + v30 = mul v27, Field 3 + v31 = mul v28, v25 + v32 = add v30, v31 + enable_side_effects u1 1 + return v32 + }"; let main = ssa.main(); let ret = match main.dfg[main.entry_block()].terminator() { @@ -1139,7 +1132,9 @@ mod test { }; let merged_values = get_all_constants_reachable_from_instruction(&main.dfg, ret); - assert_eq!(merged_values, vec![1, 3, 5, 6]); + assert_eq!(merged_values, vec![2, 3, 5, 6]); + + assert_normalized_ssa_equals(ssa, expected); } #[test] @@ -1319,23 +1314,20 @@ mod test { v9 = add v7, Field 1 v10 = cast v9 as u8 v11 = load v6 -> u8 - v12 = cast v4 as Field - v13 = cast v11 as Field - v14 = sub v9, v13 - v15 = mul v12, v14 - v16 = add v13, v15 - v17 = cast v16 as u8 + v12 = not v5 + v13 = cast v4 as u8 + v14 = cast v12 as u8 + v15 = mul v13, v10 + v16 = mul v14, v11 + v17 = add v15, v16 store v17 at v6 v18 = not v5 enable_side_effects v18 v19 = load v6 -> u8 - v20 = cast v18 as Field - v21 = cast v19 as Field - v23 = sub Field 0, v21 - v24 = mul v20, v23 - v25 = add v21, v24 - v26 = cast v25 as u8 - store v26 at v6 + v20 = cast v18 as u8 + v21 = cast v4 as u8 + v22 = mul v21, v19 + store v22 at v6 enable_side_effects u1 1 constrain v5 == u1 1 return diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs index c97572251db..6ea235b9414 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs @@ -54,6 +54,7 @@ impl<'a> ValueMerger<'a> { pub(crate) fn merge_values( &mut self, then_condition: ValueId, + else_condition: ValueId, then_value: ValueId, else_value: ValueId, ) -> ValueId { @@ -69,14 +70,15 @@ impl<'a> ValueMerger<'a> { self.dfg, self.block, then_condition, + else_condition, then_value, else_value, ), typ @ Type::Array(_, _) => { - self.merge_array_values(typ, then_condition, then_value, else_value) + self.merge_array_values(typ, then_condition, else_condition, then_value, else_value) } typ @ Type::Slice(_) => { - self.merge_slice_values(typ, then_condition, then_value, else_value) + self.merge_slice_values(typ, then_condition, else_condition, then_value, else_value) } Type::Reference(_) => panic!("Cannot return references from an if expression"), Type::Function => panic!("Cannot return functions from an if expression"), @@ -84,11 +86,12 @@ impl<'a> ValueMerger<'a> { } /// Merge two numeric values a and b from separate basic blocks to a single value. This - /// function would return the result of `if c { a } else { b }` as `c * (a-b) + b`. + /// function would return the result of `if c { a } else { b }` as `c*a + (!c)*b`. pub(crate) fn merge_numeric_values( dfg: &mut DataFlowGraph, block: BasicBlockId, then_condition: ValueId, + else_condition: ValueId, then_value: ValueId, else_value: ValueId, ) -> ValueId { @@ -111,38 +114,31 @@ impl<'a> ValueMerger<'a> { // We must cast the bool conditions to the actual numeric type used by each value. let then_condition = dfg .insert_instruction_and_results( - Instruction::Cast(then_condition, Type::field()), + Instruction::Cast(then_condition, then_type), block, None, call_stack.clone(), ) .first(); - - let then_field = Instruction::Cast(then_value, Type::field()); - let then_field_value = - dfg.insert_instruction_and_results(then_field, block, None, call_stack.clone()).first(); - - let else_field = Instruction::Cast(else_value, Type::field()); - let else_field_value = - dfg.insert_instruction_and_results(else_field, block, None, call_stack.clone()).first(); - - let diff = Instruction::binary(BinaryOp::Sub, then_field_value, else_field_value); - let diff_value = - dfg.insert_instruction_and_results(diff, block, None, call_stack.clone()).first(); - - let conditional_diff = Instruction::binary(BinaryOp::Mul, then_condition, diff_value); - let conditional_diff_value = dfg - .insert_instruction_and_results(conditional_diff, block, None, call_stack.clone()) + let else_condition = dfg + .insert_instruction_and_results( + Instruction::Cast(else_condition, else_type), + block, + None, + call_stack.clone(), + ) .first(); - let merged_field = - Instruction::binary(BinaryOp::Add, else_field_value, conditional_diff_value); - let merged_field_value = dfg - .insert_instruction_and_results(merged_field, block, None, call_stack.clone()) - .first(); + let mul = Instruction::binary(BinaryOp::Mul, then_condition, then_value); + let then_value = + dfg.insert_instruction_and_results(mul, block, None, call_stack.clone()).first(); + + let mul = Instruction::binary(BinaryOp::Mul, else_condition, else_value); + let else_value = + dfg.insert_instruction_and_results(mul, block, None, call_stack.clone()).first(); - let merged = Instruction::Cast(merged_field_value, then_type); - dfg.insert_instruction_and_results(merged, block, None, call_stack).first() + let add = Instruction::binary(BinaryOp::Add, then_value, else_value); + dfg.insert_instruction_and_results(add, block, None, call_stack).first() } /// Given an if expression that returns an array: `if c { array1 } else { array2 }`, @@ -152,6 +148,7 @@ impl<'a> ValueMerger<'a> { &mut self, typ: Type, then_condition: ValueId, + else_condition: ValueId, then_value: ValueId, else_value: ValueId, ) -> ValueId { @@ -166,6 +163,7 @@ impl<'a> ValueMerger<'a> { if let Some(result) = self.try_merge_only_changed_indices( then_condition, + else_condition, then_value, else_value, actual_length, @@ -196,7 +194,12 @@ impl<'a> ValueMerger<'a> { let then_element = get_element(then_value, typevars.clone()); let else_element = get_element(else_value, typevars); - merged.push_back(self.merge_values(then_condition, then_element, else_element)); + merged.push_back(self.merge_values( + then_condition, + else_condition, + then_element, + else_element, + )); } } @@ -209,6 +212,7 @@ impl<'a> ValueMerger<'a> { &mut self, typ: Type, then_condition: ValueId, + else_condition: ValueId, then_value_id: ValueId, else_value_id: ValueId, ) -> ValueId { @@ -269,7 +273,12 @@ impl<'a> ValueMerger<'a> { let else_element = get_element(else_value_id, typevars, else_len * element_types.len() as u32); - merged.push_back(self.merge_values(then_condition, then_element, else_element)); + merged.push_back(self.merge_values( + then_condition, + else_condition, + then_element, + else_element, + )); } } @@ -318,6 +327,7 @@ impl<'a> ValueMerger<'a> { fn try_merge_only_changed_indices( &mut self, then_condition: ValueId, + else_condition: ValueId, then_value: ValueId, else_value: ValueId, array_length: u32, @@ -401,7 +411,8 @@ impl<'a> ValueMerger<'a> { let then_element = get_element(then_value, typevars.clone()); let else_element = get_element(else_value, typevars); - let value = self.merge_values(then_condition, then_element, else_element); + let value = + self.merge_values(then_condition, else_condition, then_element, else_element); array = self.insert_array_set(array, index, value, Some(condition)).first(); } diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs index 182e6e54d0f..02191801fcd 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs @@ -66,8 +66,9 @@ impl Context { for instruction in instructions { match &function.dfg[instruction] { - Instruction::IfElse { then_condition, then_value, else_value } => { + Instruction::IfElse { then_condition, then_value, else_condition, else_value } => { let then_condition = *then_condition; + let else_condition = *else_condition; let then_value = *then_value; let else_value = *else_value; @@ -84,7 +85,12 @@ impl Context { call_stack, ); - let value = value_merger.merge_values(then_condition, then_value, else_value); + let value = value_merger.merge_values( + then_condition, + else_condition, + then_value, + else_value, + ); let _typ = function.dfg.type_of_value(value); let results = function.dfg.instruction_results(instruction);