diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index edf0461430f..a1c96a3cd23 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -472,9 +472,9 @@ impl Context { self.acir_context.assert_eq_var(lhs, rhs, assert_message.clone())?; } } - Instruction::Cast(value_id, typ) => { - let result_acir_var = self.convert_ssa_cast(value_id, typ, dfg)?; - self.define_result_var(dfg, instruction_id, result_acir_var); + Instruction::Cast(value_id, _) => { + let acir_var = self.convert_numeric_value(*value_id, dfg)?; + self.define_result_var(dfg, instruction_id, acir_var); } Instruction::Call { func, arguments } => { let result_ids = dfg.instruction_results(instruction_id); @@ -1636,41 +1636,6 @@ impl Context { } } - /// Returns an `AcirVar` that is constrained to fit in the target type by truncating the input. - /// If the target cast is to a `NativeField`, no truncation is required so the cast becomes a - /// no-op. - fn convert_ssa_cast( - &mut self, - value_id: &ValueId, - typ: &Type, - dfg: &DataFlowGraph, - ) -> Result { - let (variable, incoming_type) = match self.convert_value(*value_id, dfg) { - AcirValue::Var(variable, typ) => (variable, typ), - AcirValue::DynamicArray(_) | AcirValue::Array(_) => { - unreachable!("Cast is only applied to numerics") - } - }; - let target_numeric = match typ { - Type::Numeric(numeric) => numeric, - _ => unreachable!("Can only cast to a numeric"), - }; - match target_numeric { - NumericType::NativeField => { - // Casting into a Field as a no-op - Ok(variable) - } - NumericType::Unsigned { bit_size } | NumericType::Signed { bit_size } => { - let max_bit_size = incoming_type.bit_size(); - if max_bit_size <= *bit_size { - // Incoming variable already fits into target bit size - this is a no-op - return Ok(variable); - } - self.acir_context.truncate_var(variable, *bit_size, max_bit_size) - } - } - } - /// Returns an `AcirVar`that is constrained to be result of the truncation. fn convert_ssa_truncate( &mut self, diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index abddbfb74c7..931aee9d079 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -160,7 +160,7 @@ impl DataFlowGraph { call_stack: CallStack, ) -> InsertInstructionResult { use InsertInstructionResult::*; - match instruction.simplify(self, block, ctrl_typevars.clone()) { + match instruction.simplify(self, block, ctrl_typevars.clone(), &call_stack) { SimplifyResult::SimplifiedTo(simplification) => SimplifiedTo(simplification), SimplifyResult::SimplifiedToMultiple(simplification) => { SimplifiedToMultiple(simplification) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 628ad638e64..9691017f04b 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -236,10 +236,7 @@ impl Instruction { // In ACIR, a division with a false predicate outputs (0,0), so it cannot replace another instruction unless they have the same predicate bin.operator != BinaryOp::Div } - Cast(_, _) | Not(_) | ArrayGet { .. } | ArraySet { .. } => true, - - // Unclear why this instruction causes problems. - Truncate { .. } => false, + Cast(_, _) | Truncate { .. } | Not(_) | ArrayGet { .. } | ArraySet { .. } => true, // These either have side-effects or interact with memory Constrain(..) @@ -408,6 +405,7 @@ impl Instruction { dfg: &mut DataFlowGraph, block: BasicBlockId, ctrl_typevars: Option>, + call_stack: &CallStack, ) -> SimplifyResult { use SimplifyResult::*; match self { @@ -551,7 +549,7 @@ impl Instruction { } } Instruction::Call { func, arguments } => { - simplify_call(*func, arguments, dfg, block, ctrl_typevars) + simplify_call(*func, arguments, dfg, block, ctrl_typevars, call_stack) } Instruction::EnableSideEffects { condition } => { if let Some(last) = dfg[block].instructions().last().copied() { diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index b07e2df7bd3..edfc50a700f 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -31,6 +31,7 @@ pub(super) fn simplify_call( dfg: &mut DataFlowGraph, block: BasicBlockId, ctrl_typevars: Option>, + call_stack: &CallStack, ) -> SimplifyResult { let intrinsic = match &dfg[func] { Value::Intrinsic(intrinsic) => *intrinsic, @@ -242,7 +243,24 @@ pub(super) fn simplify_call( SimplifyResult::SimplifiedToInstruction(instruction) } Intrinsic::FromField => { - let instruction = Instruction::Cast(arguments[0], ctrl_typevars.unwrap().remove(0)); + let incoming_type = Type::field(); + let target_type = ctrl_typevars.unwrap().remove(0); + + let truncate = Instruction::Truncate { + value: arguments[0], + bit_size: target_type.bit_size(), + max_bit_size: incoming_type.bit_size(), + }; + let truncated_value = dfg + .insert_instruction_and_results( + truncate, + block, + Some(vec![incoming_type]), + call_stack.clone(), + ) + .first(); + + let instruction = Instruction::Cast(truncated_value, target_type); SimplifyResult::SimplifiedToInstruction(instruction) } } diff --git a/compiler/noirc_evaluator/src/ssa/ir/types.rs b/compiler/noirc_evaluator/src/ssa/ir/types.rs index bae06a805d0..ae53c7705c2 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/types.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/types.rs @@ -18,6 +18,28 @@ pub enum NumericType { NativeField, } +impl NumericType { + /// Returns the bit size of the provided numeric type. + pub(crate) fn bit_size(self: &NumericType) -> u32 { + match self { + NumericType::NativeField => FieldElement::max_num_bits(), + NumericType::Unsigned { bit_size } | NumericType::Signed { bit_size } => *bit_size, + } + } + + /// Returns true if the given Field value is within the numeric limits + /// for the current NumericType. + pub(crate) fn value_is_within_limits(self, field: FieldElement) -> bool { + match self { + NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => { + let max = 2u128.pow(bit_size) - 1; + field <= max.into() + } + NumericType::NativeField => true, + } + } +} + /// All types representable in the IR. #[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] pub(crate) enum Type { @@ -68,6 +90,18 @@ impl Type { Type::Numeric(NumericType::NativeField) } + /// Returns the bit size of the provided numeric type. + /// + /// # Panics + /// + /// Panics if `self` is not a [`Type::Numeric`] + pub(crate) fn bit_size(&self) -> u32 { + match self { + Type::Numeric(numeric_type) => numeric_type.bit_size(), + other => panic!("bit_size: Expected numeric type, found {other}"), + } + } + /// Returns the size of the element type for this array/slice. /// The size of a type is defined as representing how many Fields are needed /// to represent the type. This is 1 for every primitive type, and is the number of fields @@ -122,20 +156,6 @@ impl Type { } } -impl NumericType { - /// Returns true if the given Field value is within the numeric limits - /// for the current NumericType. - pub(crate) fn value_is_within_limits(self, field: FieldElement) -> bool { - match self { - NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => { - let max = 2u128.pow(bit_size) - 1; - field <= max.into() - } - NumericType::NativeField => true, - } - } -} - /// Composite Types are essentially flattened struct or tuple types. /// Array types may have these as elements where each flattened field is /// included in the array sequentially. diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index e944d7d99d8..57c93c17fc4 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -187,7 +187,7 @@ mod test { instruction::{BinaryOp, Instruction, TerminatorInstruction}, map::Id, types::Type, - value::{Value, ValueId}, + value::Value, }, }; @@ -293,7 +293,7 @@ mod test { #[test] fn instruction_deduplication() { // fn main f0 { - // b0(v0: Field): + // b0(v0: u16): // v1 = cast v0 as u32 // v2 = cast v0 as u32 // constrain v1 v2 @@ -308,7 +308,7 @@ mod test { // Compiling main let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir); - let v0 = builder.add_parameter(Type::field()); + let v0 = builder.add_parameter(Type::unsigned(16)); let v1 = builder.insert_cast(v0, Type::unsigned(32)); let v2 = builder.insert_cast(v0, Type::unsigned(32)); @@ -322,7 +322,7 @@ mod test { // Expected output: // // fn main f0 { - // b0(v0: Field): + // b0(v0: u16): // v1 = cast v0 as u32 // } let ssa = ssa.fold_constants(); @@ -332,6 +332,6 @@ mod test { assert_eq!(instructions.len(), 1); let instruction = &main.dfg[instructions[0]]; - assert_eq!(instruction, &Instruction::Cast(ValueId::test_new(0), Type::unsigned(32))); + assert_eq!(instruction, &Instruction::Cast(v0, Type::unsigned(32))); } } diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index f5cf548ddc9..b34b667c31a 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -275,6 +275,8 @@ impl<'a> FunctionContext<'a> { let bit_width = self.builder.numeric_constant(FieldElement::from(2_i128.pow(bit_size)), Type::field()); let sign_not = self.builder.insert_binary(one, BinaryOp::Sub, sign); + + // We use unsafe casts here, this is fine as we're casting to a `field` type. let as_field = self.builder.insert_cast(input, Type::field()); let sign_field = self.builder.insert_cast(sign, Type::field()); let positive_predicate = self.builder.insert_binary(sign_field, BinaryOp::Mul, as_field); @@ -310,12 +312,12 @@ impl<'a> FunctionContext<'a> { match operator { BinaryOpKind::Add | BinaryOpKind::Subtract => { // Result is computed modulo the bit size - let mut result = - self.builder.insert_truncate(result, bit_size, bit_size + 1); - result = self.builder.insert_cast(result, Type::unsigned(bit_size)); + let result = self.builder.insert_truncate(result, bit_size, bit_size + 1); + let result = + self.insert_safe_cast(result, Type::unsigned(bit_size), location); self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location); - self.builder.insert_cast(result, result_type) + self.insert_safe_cast(result, result_type, location) } BinaryOpKind::Multiply => { // Result is computed modulo the bit size @@ -324,7 +326,7 @@ impl<'a> FunctionContext<'a> { result = self.builder.insert_truncate(result, bit_size, 2 * bit_size); self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location); - self.builder.insert_cast(result, result_type) + self.insert_safe_cast(result, result_type, location) } BinaryOpKind::ShiftLeft | BinaryOpKind::ShiftRight => { self.check_shift_overflow(result, rhs, bit_size, location, true) @@ -374,8 +376,11 @@ impl<'a> FunctionContext<'a> { is_signed: bool, ) -> ValueId { let one = self.builder.numeric_constant(FieldElement::one(), Type::bool()); - let rhs = - if is_signed { self.builder.insert_cast(rhs, Type::unsigned(bit_size)) } else { rhs }; + let rhs = if is_signed { + self.insert_safe_cast(rhs, Type::unsigned(bit_size), location) + } else { + rhs + }; // Bit-shift with a negative number is an overflow if is_signed { // We compute the sign of rhs. @@ -431,8 +436,8 @@ impl<'a> FunctionContext<'a> { Type::unsigned(bit_size), ); // We compute the sign of the operands. The overflow checks for signed integers depends on these signs - let lhs_as_unsigned = self.builder.insert_cast(lhs, Type::unsigned(bit_size)); - let rhs_as_unsigned = self.builder.insert_cast(rhs, Type::unsigned(bit_size)); + let lhs_as_unsigned = self.insert_safe_cast(lhs, Type::unsigned(bit_size), location); + let rhs_as_unsigned = self.insert_safe_cast(rhs, Type::unsigned(bit_size), location); let lhs_sign = self.builder.insert_binary(lhs_as_unsigned, BinaryOp::Lt, half_width); let mut rhs_sign = self.builder.insert_binary(rhs_as_unsigned, BinaryOp::Lt, half_width); let message = if is_sub { @@ -473,7 +478,7 @@ impl<'a> FunctionContext<'a> { // Then we check the signed product fits in a signed integer of bit_size-bits let not_same = self.builder.insert_binary(one, BinaryOp::Sub, same_sign); let not_same_sign_field = - self.builder.insert_cast(not_same, Type::unsigned(bit_size)); + self.insert_safe_cast(not_same, Type::unsigned(bit_size), location); let positive_maximum_with_offset = self.builder.insert_binary(half_width, BinaryOp::Add, not_same_sign_field); let product_overflow_check = @@ -663,6 +668,29 @@ impl<'a> FunctionContext<'a> { reshaped_return_values } + /// Inserts a cast instruction at the end of the current block and returns the results + /// of the cast. + /// + /// Compared to `self.builder.insert_cast`, this version will automatically truncate `value` to be a valid `typ`. + pub(super) fn insert_safe_cast( + &mut self, + mut value: ValueId, + typ: Type, + location: Location, + ) -> ValueId { + self.builder.set_location(location); + + // To ensure that `value` is a valid `typ`, we insert an `Instruction::Truncate` instruction beforehand if + // we're narrowing the type size. + let incoming_type_size = self.builder.type_of_value(value).bit_size(); + let target_type_size = typ.bit_size(); + if target_type_size < incoming_type_size { + value = self.builder.insert_truncate(value, target_type_size, incoming_type_size); + } + + self.builder.insert_cast(value, typ) + } + /// Create a const offset of an address for an array load or store pub(super) fn make_offset(&mut self, mut address: ValueId, offset: u128) -> ValueId { if offset != 0 { diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index c00fbbbcb40..0b8c3a37ef9 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -434,8 +434,8 @@ impl<'a> FunctionContext<'a> { fn codegen_cast(&mut self, cast: &ast::Cast) -> Result { let lhs = self.codegen_non_tuple_expression(&cast.lhs)?; let typ = Self::convert_non_tuple_type(&cast.r#type); - self.builder.set_location(cast.location); - Ok(self.builder.insert_cast(lhs, typ).into()) + + Ok(self.insert_safe_cast(lhs, typ, cast.location).into()) } /// Codegens a for loop, creating three new blocks in the process.