From 1cf6256dbded28d92915aa53b7fb048a91c86a50 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 27 Aug 2024 13:11:16 -0500 Subject: [PATCH 1/6] Add missing constant cases --- .../noirc_frontend/src/elaborator/types.rs | 9 +- .../src/hir/resolution/errors.rs | 9 ++ compiler/noirc_frontend/src/hir_def/types.rs | 111 ++++++++++++------ .../src/monomorphization/mod.rs | 1 + .../arithmetic_generics/src/main.nr | 26 ++++ 5 files changed, 117 insertions(+), 39 deletions(-) diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 44bded6b92f..9c36bff6619 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -445,14 +445,19 @@ impl<'context> Elaborator<'context> { }) } UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), - UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, span) => { let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); let lhs = self.convert_expression_type(*lhs); let rhs = self.convert_expression_type(*rhs); match (lhs, rhs) { (Type::Constant(lhs), Type::Constant(rhs)) => { - Type::Constant(op.function(lhs, rhs)) + if let Some(result) = op.function(lhs, rhs) { + Type::Constant(result) + } else { + self.push_err(ResolverError::OverflowInType { lhs, op, rhs, span }); + Type::Error + } } (lhs, rhs) => { if !self.enable_arithmetic_generics { diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index 0aad50d13b2..e67cccb33af 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -120,6 +120,8 @@ pub enum ResolverError { NamedTypeArgs { span: Span, item_kind: &'static str }, #[error("Associated constants may only be a field or integer type")] AssociatedConstantsMustBeNumeric { span: Span }, + #[error("Overflow in `{lhs} {op} {rhs}`")] + OverflowInType { lhs: u32, op: crate::BinaryTypeOperator, rhs: u32, span: Span }, } impl ResolverError { @@ -480,6 +482,13 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) } + ResolverError::OverflowInType { lhs, op, rhs, span } => { + Diagnostic::simple_error( + format!("Overflow in `{lhs} {op} {rhs}`"), + "Overflow here".to_string(), + *span, + ) + } } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 807666f9af9..a19f5d9bc09 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1672,7 +1672,9 @@ impl Type { // evaluate_to_u32 also calls canonicalize so if we just called // `self.evaluate_to_u32()` we'd get infinite recursion. if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) { - return Type::Constant(op.function(lhs, rhs)); + if let Some(result) = op.function(lhs, rhs) { + return Type::Constant(result); + } } let lhs = lhs.canonicalize(); @@ -1681,7 +1683,9 @@ impl Type { return result; } - if let Some(result) = Self::try_simplify_subtraction(&lhs, op, &rhs) { + // Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` + // where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) + if let Some(result) = Self::try_simplify_partial_constants(&lhs, op, &rhs) { return result; } @@ -1711,7 +1715,11 @@ impl Type { queue.push(*rhs); } Type::Constant(new_constant) => { - constant = op.function(constant, new_constant); + if let Some(result) = op.function(constant, new_constant) { + constant = result; + } else { + sorted.insert(Type::Constant(new_constant)); + } } other => { sorted.insert(other); @@ -1757,27 +1765,60 @@ impl Type { } } - /// Try to simplify a subtraction expression of `lhs - rhs`. + /// Given: + /// lhs = `N op C1` + /// rhs = C2 + /// Returns: `(N, op, C1, C2)` if C1 and C2 are constants. + /// Note that the operator here is within the `lhs` term, the operator + /// separating lhs and rhs is not needed. + fn parse_constant_expr( + lhs: &Type, + rhs: &Type, + ) -> Option<(Box, BinaryTypeOperator, u32, u32)> { + let rhs = rhs.evaluate_to_u32()?; + + let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { + return None; + }; + + let l_rhs = l_rhs.evaluate_to_u32()?; + Some((l_type, l_op, l_rhs, rhs)) + } + + /// Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` + /// where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) /// - /// - Simplifies `(a + C1) - C2` to `a + (C1 - C2)` if C1 and C2 are constants. - fn try_simplify_subtraction(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { + /// - Simplifies `(a +/- C1) +/- C2` to `a +/- (C1 +/- C2)` if C1 and C2 are constants. + /// - Simplifies `(a */÷ C1) */÷ C2` to `a */÷ (C1 */÷ C2)` if C1 and C2 are constants. + fn try_simplify_partial_constants( + lhs: &Type, + mut op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { use BinaryTypeOperator::*; - match lhs { - Type::InfixExpr(l_lhs, l_op, l_rhs) => { - // Simplify `(N + 2) - 1` - if op == Subtraction && *l_op == Addition { - if let (Some(lhs_const), Some(rhs_const)) = - (l_rhs.evaluate_to_u32(), rhs.evaluate_to_u32()) - { - if lhs_const > rhs_const { - let constant = Box::new(Type::Constant(lhs_const - rhs_const)); - return Some( - Type::InfixExpr(l_lhs.clone(), *l_op, constant).canonicalize(), - ); - } - } + let (l_type, l_op, l_const, r_const) = Type::parse_constant_expr(lhs, rhs)?; + + match (l_op, op) { + (Addition | Subtraction, Addition | Subtraction) => { + // If l_op is a subtraction we want to inverse the rhs operator. + if l_op == Subtraction { + op = op.inverse()?; + } + let result = op.function(l_const, r_const)?; + Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) + } + (Multiplication | Division, Multiplication | Division) => { + // If l_op is a subtraction we want to inverse the rhs operator. + if l_op == Division { + op = op.inverse()?; + } + // If op is a division we need to ensure it divides evenly + if op == Division && l_const % r_const != 0 { + None + } else { + let result = op.function(l_const, r_const)?; + Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) } - None } _ => None, } @@ -1926,7 +1967,7 @@ impl Type { Type::InfixExpr(lhs, op, rhs) => { let lhs = lhs.evaluate_to_u32()?; let rhs = rhs.evaluate_to_u32()?; - Some(op.function(lhs, rhs)) + op.function(lhs, rhs) } _ => None, } @@ -2030,17 +2071,13 @@ impl Type { Type::Forall(typevars, typ) => { assert_eq!(types.len() + implicit_generic_count, typevars.len(), "Turbofish operator used with incorrect generic count which was not caught by name resolution"); + let bindings = + (0..implicit_generic_count).map(|_| interner.next_type_variable()).chain(types); + let replacements = typevars .iter() - .enumerate() - .map(|(i, var)| { - let binding = if i < implicit_generic_count { - interner.next_type_variable() - } else { - types[i - implicit_generic_count].clone() - }; - (var.id(), (var.clone(), binding)) - }) + .zip(bindings) + .map(|(var, binding)| (var.id(), (var.clone(), binding))) .collect(); let instantiated = typ.substitute(&replacements); @@ -2457,13 +2494,13 @@ fn convert_array_expression_to_slice( impl BinaryTypeOperator { /// Perform the actual rust numeric operation associated with this operator - pub fn function(self, a: u32, b: u32) -> u32 { + pub fn function(self, a: u32, b: u32) -> Option { match self { - BinaryTypeOperator::Addition => a.wrapping_add(b), - BinaryTypeOperator::Subtraction => a.wrapping_sub(b), - BinaryTypeOperator::Multiplication => a.wrapping_mul(b), - BinaryTypeOperator::Division => a.wrapping_div(b), - BinaryTypeOperator::Modulo => a.wrapping_rem(b), + BinaryTypeOperator::Addition => a.checked_add(b), + BinaryTypeOperator::Subtraction => a.checked_sub(b), + BinaryTypeOperator::Multiplication => a.checked_mul(b), + BinaryTypeOperator::Division => a.checked_div(b), + BinaryTypeOperator::Modulo => a.checked_rem(b), } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index edb831b2158..7ebca79e171 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -858,6 +858,7 @@ impl<'interner> Monomorphizer<'interner> { generics.unwrap_or_default(), None, ); + eprintln!("Converting type of function: {typ:?}"); let typ = Self::convert_type(&typ, ident.location)?; let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; let ident_expression = ast::Expression::Ident(ident); diff --git a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr index 6cd13ab0e2f..028581dd39f 100644 --- a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr +++ b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr @@ -7,6 +7,9 @@ fn main() { let _ = split_first([1, 2, 3]); let _ = push_multiple([1, 2, 3]); + + test_constant_folding(); + test_non_constant_folding(); } fn split_first(array: [T; N]) -> (T, [T; N - 1]) { @@ -101,3 +104,26 @@ fn demo_proof() -> Equiv, (Equiv, (), W, () let p3: Equiv, (), W, ()> = add_equiv_r::(p3_sub); equiv_trans(equiv_trans(p1, p2), p3) } + +// Helper function as a quick way to make a particular numeric generic +fn value() -> W { + W {} +} + +fn test_constant_folding() { + // N + C1 - C2 = N + (C1 - C2) + let _: W = value::(); + + // N - C1 + C2 = N - (C1 - C2) + let _: W = value::(); + + // N * C1 / C2 = N * (C1 / C2) + let _: W = value::(); + + // N / C1 * C2 = N / (C1 / C2) + let _: W = value::(); +} + +fn test_non_constant_folding() { + // TODO +} From 2baefceafbdc113672af1ba5cb37c0ac3f390e65 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 27 Aug 2024 13:37:36 -0500 Subject: [PATCH 2/6] Add non-constant cases --- compiler/noirc_frontend/src/hir_def/types.rs | 36 +++++++++---------- .../src/monomorphization/mod.rs | 1 - .../arithmetic_generics/src/main.nr | 31 +++++++++------- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index a19f5d9bc09..9b6979b68f8 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1679,7 +1679,7 @@ impl Type { let lhs = lhs.canonicalize(); let rhs = rhs.canonicalize(); - if let Some(result) = Self::try_simplify_addition(&lhs, op, &rhs) { + if let Some(result) = Self::try_simplify_non_constants(&lhs, op, &rhs) { return result; } @@ -1745,24 +1745,22 @@ impl Type { } } - /// Try to simplify an addition expression of `lhs + rhs`. + /// Try to simplify non-constant expressions in the form `(N op1 M) op2 M` + /// where the two `M` terms are expected to cancel out. /// - /// - Simplifies `(a - b) + b` to `a`. - fn try_simplify_addition(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { - use BinaryTypeOperator::*; - match lhs { - Type::InfixExpr(l_lhs, l_op, l_rhs) => { - if op == Addition && *l_op == Subtraction { - // TODO: Propagate type bindings. Can do in another PR, this one is large enough. - let unifies = l_rhs.try_unify(rhs, &mut TypeBindings::new()); - if unifies.is_ok() { - return Some(l_lhs.as_ref().clone()); - } - } - None - } - _ => None, + /// - Simplifies `(N +/- M) -/+ M` to `a` + /// - Simplifies `(N */÷ M) ÷/* M` to `a` + fn try_simplify_non_constants(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { + let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { + return None; + }; + + // Note that this is exact, syntactic equality, not unification. + if l_op.inverse() != Some(op) || l_rhs.canonicalize() != rhs.canonicalize() { + return None; } + + Some(*l_type) } /// Given: @@ -1788,8 +1786,8 @@ impl Type { /// Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` /// where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) /// - /// - Simplifies `(a +/- C1) +/- C2` to `a +/- (C1 +/- C2)` if C1 and C2 are constants. - /// - Simplifies `(a */÷ C1) */÷ C2` to `a */÷ (C1 */÷ C2)` if C1 and C2 are constants. + /// - Simplifies `(N +/- C1) +/- C2` to `N +/- (C1 +/- C2)` if C1 and C2 are constants. + /// - Simplifies `(N */÷ C1) */÷ C2` to `N */÷ (C1 */÷ C2)` if C1 and C2 are constants. fn try_simplify_partial_constants( lhs: &Type, mut op: BinaryTypeOperator, diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 7ebca79e171..edb831b2158 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -858,7 +858,6 @@ impl<'interner> Monomorphizer<'interner> { generics.unwrap_or_default(), None, ); - eprintln!("Converting type of function: {typ:?}"); let typ = Self::convert_type(&typ, ident.location)?; let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; let ident_expression = ast::Expression::Ident(ident); diff --git a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr index 028581dd39f..ad8dff6c7b9 100644 --- a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr +++ b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr @@ -8,8 +8,8 @@ fn main() { let _ = push_multiple([1, 2, 3]); - test_constant_folding(); - test_non_constant_folding(); + test_constant_folding::<10>(); + test_non_constant_folding::<10, 20>(); } fn split_first(array: [T; N]) -> (T, [T; N - 1]) { @@ -105,25 +105,30 @@ fn demo_proof() -> Equiv, (Equiv, (), W, () equiv_trans(equiv_trans(p1, p2), p3) } -// Helper function as a quick way to make a particular numeric generic -fn value() -> W { - W {} -} - fn test_constant_folding() { // N + C1 - C2 = N + (C1 - C2) - let _: W = value::(); + let _: W = W:: {}; // N - C1 + C2 = N - (C1 - C2) - let _: W = value::(); + let _: W = W:: {}; // N * C1 / C2 = N * (C1 / C2) - let _: W = value::(); + let _: W = W:: {}; // N / C1 * C2 = N / (C1 / C2) - let _: W = value::(); + let _: W = W:: {}; } -fn test_non_constant_folding() { - // TODO +fn test_non_constant_folding() { + // N + M - M = N + let _: W = W:: {}; + + // N - M + M = N + let _: W = W:: {}; + + // N * M / M = N + let _: W = W:: {}; + + // N / M * M = N + let _: W = W:: {}; } From 6b62a9c5135b58008bae364994ee09b5024e5925 Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 28 Aug 2024 10:57:23 -0500 Subject: [PATCH 3/6] Update compiler/noirc_frontend/src/hir_def/types.rs Co-authored-by: Michael J Klein --- compiler/noirc_frontend/src/hir_def/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 9b6979b68f8..8f9da53ae45 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1806,7 +1806,7 @@ impl Type { Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) } (Multiplication | Division, Multiplication | Division) => { - // If l_op is a subtraction we want to inverse the rhs operator. + // If l_op is a division we want to inverse the rhs operator. if l_op == Division { op = op.inverse()?; } From c12f464edcf8baf513708d994b7a5d9f81666d93 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 28 Aug 2024 10:59:13 -0500 Subject: [PATCH 4/6] Add divide by zero check --- compiler/noirc_frontend/src/hir_def/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 8f9da53ae45..b9d9f333841 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1811,7 +1811,7 @@ impl Type { op = op.inverse()?; } // If op is a division we need to ensure it divides evenly - if op == Division && l_const % r_const != 0 { + if op == Division && &&(r_const == 0 || l_const % r_const != 0) { None } else { let result = op.function(l_const, r_const)?; From 76f5d1e1561388bb414d6b21e640d74c6c2232bf Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 28 Aug 2024 12:19:38 -0500 Subject: [PATCH 5/6] Fix extra operator --- compiler/noirc_frontend/src/hir/type_check/generics.rs | 1 + compiler/noirc_frontend/src/hir_def/types.rs | 2 +- compiler/noirc_frontend/src/monomorphization/errors.rs | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/type_check/generics.rs b/compiler/noirc_frontend/src/hir/type_check/generics.rs index 379c53944e5..697c78745f9 100644 --- a/compiler/noirc_frontend/src/hir/type_check/generics.rs +++ b/compiler/noirc_frontend/src/hir/type_check/generics.rs @@ -160,6 +160,7 @@ fn fmt_trait_generics( write!(f, "{} = {}", named.name, named.typ)?; } } + write!(f, ">")?; } Ok(()) } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index b9d9f333841..03caae5e9b3 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1811,7 +1811,7 @@ impl Type { op = op.inverse()?; } // If op is a division we need to ensure it divides evenly - if op == Division && &&(r_const == 0 || l_const % r_const != 0) { + if op == Division && (r_const == 0 || l_const % r_const != 0) { None } else { let result = op.function(l_const, r_const)?; diff --git a/compiler/noirc_frontend/src/monomorphization/errors.rs b/compiler/noirc_frontend/src/monomorphization/errors.rs index 665bf26f7b9..ce8ef3572e6 100644 --- a/compiler/noirc_frontend/src/monomorphization/errors.rs +++ b/compiler/noirc_frontend/src/monomorphization/errors.rs @@ -34,7 +34,7 @@ impl MonomorphizationError { fn into_diagnostic(self) -> CustomDiagnostic { let message = match &self { MonomorphizationError::UnknownArrayLength { length, .. } => { - format!("ICE: Could not determine array length `{length}`") + format!("Could not determine array length `{length}`") } MonomorphizationError::NoDefaultType { location } => { let message = "Type annotation needed".into(); From a2b6457454c64ff1bb3594cfb80bf8f27eee8445 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Wed, 28 Aug 2024 14:06:39 -0500 Subject: [PATCH 6/6] Add custom Eq & Hash impl for type for canonicalize check --- compiler/noirc_frontend/src/hir_def/types.rs | 302 ++++++++---------- .../src/hir_def/types/arithmetic.rs | 215 +++++++++++++ .../src/monomorphization/mod.rs | 2 + 3 files changed, 353 insertions(+), 166 deletions(-) create mode 100644 compiler/noirc_frontend/src/hir_def/types/arithmetic.rs diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 03caae5e9b3..c59c86b9616 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -24,7 +24,9 @@ use super::{ traits::NamedType, }; -#[derive(PartialEq, Eq, Clone, Hash, Ord, PartialOrd)] +mod arithmetic; + +#[derive(Eq, Clone, Ord, PartialOrd)] pub enum Type { /// A primitive Field type FieldElement, @@ -1657,171 +1659,6 @@ impl Type { } } - /// Try to canonicalize the representation of this type. - /// Currently the only type with a canonical representation is - /// `Type::Infix` where for each consecutive commutative operator - /// we sort the non-constant operands by `Type: Ord` and place all constant - /// operands at the end, constant folded. - /// - /// For example: - /// - `canonicalize[((1 + N) + M) + 2] = (M + N) + 3` - /// - `canonicalize[A + 2 * B + 3 - 2] = A + (B * 2) + 3 - 2` - pub fn canonicalize(&self) -> Type { - match self.follow_bindings() { - Type::InfixExpr(lhs, op, rhs) => { - // evaluate_to_u32 also calls canonicalize so if we just called - // `self.evaluate_to_u32()` we'd get infinite recursion. - if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) { - if let Some(result) = op.function(lhs, rhs) { - return Type::Constant(result); - } - } - - let lhs = lhs.canonicalize(); - let rhs = rhs.canonicalize(); - if let Some(result) = Self::try_simplify_non_constants(&lhs, op, &rhs) { - return result; - } - - // Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` - // where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) - if let Some(result) = Self::try_simplify_partial_constants(&lhs, op, &rhs) { - return result; - } - - if op.is_commutative() { - return Self::sort_commutative(&lhs, op, &rhs); - } - - Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)) - } - other => other, - } - } - - fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type { - let mut queue = vec![lhs.clone(), rhs.clone()]; - - let mut sorted = BTreeSet::new(); - - let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 }; - let mut constant = zero_value; - - // Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator. - while let Some(item) = queue.pop() { - match item.canonicalize() { - Type::InfixExpr(lhs, new_op, rhs) if new_op == op => { - queue.push(*lhs); - queue.push(*rhs); - } - Type::Constant(new_constant) => { - if let Some(result) = op.function(constant, new_constant) { - constant = result; - } else { - sorted.insert(Type::Constant(new_constant)); - } - } - other => { - sorted.insert(other); - } - } - } - - if let Some(first) = sorted.pop_first() { - let mut typ = first.clone(); - - for rhs in sorted { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone())); - } - - if constant != zero_value { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(Type::Constant(constant))); - } - - typ - } else { - // Every type must have been a constant - Type::Constant(constant) - } - } - - /// Try to simplify non-constant expressions in the form `(N op1 M) op2 M` - /// where the two `M` terms are expected to cancel out. - /// - /// - Simplifies `(N +/- M) -/+ M` to `a` - /// - Simplifies `(N */÷ M) ÷/* M` to `a` - fn try_simplify_non_constants(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { - let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { - return None; - }; - - // Note that this is exact, syntactic equality, not unification. - if l_op.inverse() != Some(op) || l_rhs.canonicalize() != rhs.canonicalize() { - return None; - } - - Some(*l_type) - } - - /// Given: - /// lhs = `N op C1` - /// rhs = C2 - /// Returns: `(N, op, C1, C2)` if C1 and C2 are constants. - /// Note that the operator here is within the `lhs` term, the operator - /// separating lhs and rhs is not needed. - fn parse_constant_expr( - lhs: &Type, - rhs: &Type, - ) -> Option<(Box, BinaryTypeOperator, u32, u32)> { - let rhs = rhs.evaluate_to_u32()?; - - let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { - return None; - }; - - let l_rhs = l_rhs.evaluate_to_u32()?; - Some((l_type, l_op, l_rhs, rhs)) - } - - /// Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` - /// where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) - /// - /// - Simplifies `(N +/- C1) +/- C2` to `N +/- (C1 +/- C2)` if C1 and C2 are constants. - /// - Simplifies `(N */÷ C1) */÷ C2` to `N */÷ (C1 */÷ C2)` if C1 and C2 are constants. - fn try_simplify_partial_constants( - lhs: &Type, - mut op: BinaryTypeOperator, - rhs: &Type, - ) -> Option { - use BinaryTypeOperator::*; - let (l_type, l_op, l_const, r_const) = Type::parse_constant_expr(lhs, rhs)?; - - match (l_op, op) { - (Addition | Subtraction, Addition | Subtraction) => { - // If l_op is a subtraction we want to inverse the rhs operator. - if l_op == Subtraction { - op = op.inverse()?; - } - let result = op.function(l_const, r_const)?; - Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) - } - (Multiplication | Division, Multiplication | Division) => { - // If l_op is a division we want to inverse the rhs operator. - if l_op == Division { - op = op.inverse()?; - } - // If op is a division we need to ensure it divides evenly - if op == Division && (r_const == 0 || l_const % r_const != 0) { - None - } else { - let result = op.function(l_const, r_const)?; - Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) - } - } - _ => None, - } - } - /// Try to unify a type variable to `self`. /// This is a helper function factored out from try_unify. fn try_unify_to_type_variable( @@ -2716,3 +2553,136 @@ impl std::fmt::Debug for StructType { write!(f, "{}", self.name) } } + +impl std::hash::Hash for Type { + fn hash(&self, state: &mut H) { + if let Some(variable) = self.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + typ.hash(state); + return; + } + } + + if !matches!(self, Type::TypeVariable(..) | Type::NamedGeneric(..)) { + std::mem::discriminant(self).hash(state); + } + + match self { + Type::FieldElement | Type::Bool | Type::Unit | Type::Error => (), + Type::Array(len, elem) => { + len.hash(state); + elem.hash(state); + } + Type::Slice(elem) => elem.hash(state), + Type::Integer(sign, bits) => { + sign.hash(state); + bits.hash(state); + } + Type::String(len) => len.hash(state), + Type::FmtString(len, env) => { + len.hash(state); + env.hash(state); + } + Type::Tuple(elems) => elems.hash(state), + Type::Struct(def, args) => { + def.hash(state); + args.hash(state); + } + Type::Alias(alias, args) => { + alias.hash(state); + args.hash(state); + } + Type::TypeVariable(var, _) | Type::NamedGeneric(var, ..) => var.hash(state), + Type::TraitAsType(trait_id, _, args) => { + trait_id.hash(state); + args.hash(state); + } + Type::Function(args, ret, env, is_unconstrained) => { + args.hash(state); + ret.hash(state); + env.hash(state); + is_unconstrained.hash(state); + } + Type::MutableReference(elem) => elem.hash(state), + Type::Forall(vars, typ) => { + vars.hash(state); + typ.hash(state); + } + Type::Constant(value) => value.hash(state), + Type::Quoted(typ) => typ.hash(state), + Type::InfixExpr(lhs, op, rhs) => { + lhs.hash(state); + op.hash(state); + rhs.hash(state); + } + } + } +} + +impl PartialEq for Type { + fn eq(&self, other: &Self) -> bool { + if let Some(variable) = self.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + return typ == other; + } + } + + if let Some(variable) = other.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + return self == typ; + } + } + + use Type::*; + match (self, other) { + (FieldElement, FieldElement) | (Bool, Bool) | (Unit, Unit) | (Error, Error) => true, + (Array(lhs_len, lhs_elem), Array(rhs_len, rhs_elem)) => { + lhs_len == rhs_len && lhs_elem == rhs_elem + } + (Slice(lhs_elem), Slice(rhs_elem)) => lhs_elem == rhs_elem, + (Integer(lhs_sign, lhs_bits), Integer(rhs_sign, rhs_bits)) => { + lhs_sign == rhs_sign && lhs_bits == rhs_bits + } + (String(lhs_len), String(rhs_len)) => lhs_len == rhs_len, + (FmtString(lhs_len, lhs_env), FmtString(rhs_len, rhs_env)) => { + lhs_len == rhs_len && lhs_env == rhs_env + } + (Tuple(lhs_types), Tuple(rhs_types)) => lhs_types == rhs_types, + (Struct(lhs_struct, lhs_generics), Struct(rhs_struct, rhs_generics)) => { + lhs_struct == rhs_struct && lhs_generics == rhs_generics + } + (Alias(lhs_alias, lhs_generics), Alias(rhs_alias, rhs_generics)) => { + lhs_alias == rhs_alias && lhs_generics == rhs_generics + } + (TraitAsType(lhs_trait, _, lhs_generics), TraitAsType(rhs_trait, _, rhs_generics)) => { + lhs_trait == rhs_trait && lhs_generics == rhs_generics + } + ( + Function(lhs_args, lhs_ret, lhs_env, lhs_unconstrained), + Function(rhs_args, rhs_ret, rhs_env, rhs_unconstrained), + ) => { + let args_and_ret_eq = lhs_args == rhs_args && lhs_ret == rhs_ret; + args_and_ret_eq && lhs_env == rhs_env && lhs_unconstrained == rhs_unconstrained + } + (MutableReference(lhs_elem), MutableReference(rhs_elem)) => lhs_elem == rhs_elem, + (Forall(lhs_vars, lhs_type), Forall(rhs_vars, rhs_type)) => { + lhs_vars == rhs_vars && lhs_type == rhs_type + } + (Constant(lhs), Constant(rhs)) => lhs == rhs, + (Quoted(lhs), Quoted(rhs)) => lhs == rhs, + (InfixExpr(l_lhs, l_op, l_rhs), InfixExpr(r_lhs, r_op, r_rhs)) => { + l_lhs == r_lhs && l_op == r_op && l_rhs == r_rhs + } + // Special case: we consider unbound named generics and type variables to be equal to each + // other if their type variable ids match. This is important for some corner cases in + // monomorphization where we call `replace_named_generics_with_type_variables` but + // still want them to be equal for canonicalization checks in arithmetic generics. + // Without this we'd fail the `serialize` test. + ( + NamedGeneric(lhs_var, _, _) | TypeVariable(lhs_var, _), + NamedGeneric(rhs_var, _, _) | TypeVariable(rhs_var, _), + ) => lhs_var.id() == rhs_var.id(), + _ => false, + } + } +} diff --git a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs new file mode 100644 index 00000000000..ad07185dff1 --- /dev/null +++ b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -0,0 +1,215 @@ +use std::collections::BTreeSet; + +use crate::{BinaryTypeOperator, Type}; + +impl Type { + /// Try to canonicalize the representation of this type. + /// Currently the only type with a canonical representation is + /// `Type::Infix` where for each consecutive commutative operator + /// we sort the non-constant operands by `Type: Ord` and place all constant + /// operands at the end, constant folded. + /// + /// For example: + /// - `canonicalize[((1 + N) + M) + 2] = (M + N) + 3` + /// - `canonicalize[A + 2 * B + 3 - 2] = A + (B * 2) + 3 - 2` + pub fn canonicalize(&self) -> Type { + match self.follow_bindings() { + Type::InfixExpr(lhs, op, rhs) => { + // evaluate_to_u32 also calls canonicalize so if we just called + // `self.evaluate_to_u32()` we'd get infinite recursion. + if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) { + if let Some(result) = op.function(lhs, rhs) { + return Type::Constant(result); + } + } + + let lhs = lhs.canonicalize(); + let rhs = rhs.canonicalize(); + if let Some(result) = Self::try_simplify_non_constants_in_lhs(&lhs, op, &rhs) { + return result.canonicalize(); + } + + if let Some(result) = Self::try_simplify_non_constants_in_rhs(&lhs, op, &rhs) { + return result.canonicalize(); + } + + // Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` + // where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) + if let Some(result) = Self::try_simplify_partial_constants(&lhs, op, &rhs) { + return result.canonicalize(); + } + + if op.is_commutative() { + return Self::sort_commutative(&lhs, op, &rhs); + } + + Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)) + } + other => other, + } + } + + fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type { + let mut queue = vec![lhs.clone(), rhs.clone()]; + + let mut sorted = BTreeSet::new(); + + let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 }; + let mut constant = zero_value; + + // Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator. + while let Some(item) = queue.pop() { + match item.canonicalize() { + Type::InfixExpr(lhs, new_op, rhs) if new_op == op => { + queue.push(*lhs); + queue.push(*rhs); + } + Type::Constant(new_constant) => { + if let Some(result) = op.function(constant, new_constant) { + constant = result; + } else { + sorted.insert(Type::Constant(new_constant)); + } + } + other => { + sorted.insert(other); + } + } + } + + if let Some(first) = sorted.pop_first() { + let mut typ = first.clone(); + + for rhs in sorted { + typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone())); + } + + if constant != zero_value { + typ = Type::InfixExpr(Box::new(typ), op, Box::new(Type::Constant(constant))); + } + + typ + } else { + // Every type must have been a constant + Type::Constant(constant) + } + } + + /// Try to simplify non-constant expressions in the form `(N op1 M) op2 M` + /// where the two `M` terms are expected to cancel out. + /// Precondition: `lhs & rhs are in canonical form` + /// + /// - Simplifies `(N +/- M) -/+ M` to `N` + /// - Simplifies `(N */÷ M) ÷/* M` to `N` + fn try_simplify_non_constants_in_lhs( + lhs: &Type, + op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + let Type::InfixExpr(l_lhs, l_op, l_rhs) = lhs.follow_bindings() else { + return None; + }; + + // Note that this is exact, syntactic equality, not unification. + // `rhs` is expected to already be in canonical form. + if l_op.inverse() != Some(op) || l_rhs.canonicalize() != *rhs { + return None; + } + + Some(*l_lhs) + } + + /// Try to simplify non-constant expressions in the form `N op1 (M op1 N)` + /// where the two `M` terms are expected to cancel out. + /// Precondition: `lhs & rhs are in canonical form` + /// + /// Unlike `try_simplify_non_constants_in_lhs` we can't simplify `N / (M * N)` + /// Since that should simplify to `1 / M` instead of `M`. + /// + /// - Simplifies `N +/- (M -/+ N)` to `M` + /// - Simplifies `N * (M ÷ N)` to `M` + fn try_simplify_non_constants_in_rhs( + lhs: &Type, + op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + let Type::InfixExpr(r_lhs, r_op, r_rhs) = rhs.follow_bindings() else { + return None; + }; + + // `N / (M * N)` should be simplified to `1 / M`, but we only handle + // simplifying to `M` in this function. + if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication { + return None; + } + + // Note that this is exact, syntactic equality, not unification. + // `lhs` is expected to already be in canonical form. + if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize() { + return None; + } + + Some(*r_lhs) + } + + /// Given: + /// lhs = `N op C1` + /// rhs = C2 + /// Returns: `(N, op, C1, C2)` if C1 and C2 are constants. + /// Note that the operator here is within the `lhs` term, the operator + /// separating lhs and rhs is not needed. + /// Precondition: `lhs & rhs are in canonical form` + fn parse_partial_constant_expr( + lhs: &Type, + rhs: &Type, + ) -> Option<(Box, BinaryTypeOperator, u32, u32)> { + let rhs = rhs.evaluate_to_u32()?; + + let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { + return None; + }; + + let l_rhs = l_rhs.evaluate_to_u32()?; + Some((l_type, l_op, l_rhs, rhs)) + } + + /// Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` + /// where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) + /// Precondition: `lhs & rhs are in canonical form` + /// + /// - Simplifies `(N +/- C1) +/- C2` to `N +/- (C1 +/- C2)` if C1 and C2 are constants. + /// - Simplifies `(N */÷ C1) */÷ C2` to `N */÷ (C1 */÷ C2)` if C1 and C2 are constants. + fn try_simplify_partial_constants( + lhs: &Type, + mut op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + use BinaryTypeOperator::*; + let (l_type, l_op, l_const, r_const) = Type::parse_partial_constant_expr(lhs, rhs)?; + + match (l_op, op) { + (Addition | Subtraction, Addition | Subtraction) => { + // If l_op is a subtraction we want to inverse the rhs operator. + if l_op == Subtraction { + op = op.inverse()?; + } + let result = op.function(l_const, r_const)?; + Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) + } + (Multiplication | Division, Multiplication | Division) => { + // If l_op is a division we want to inverse the rhs operator. + if l_op == Division { + op = op.inverse()?; + } + // If op is a division we need to ensure it divides evenly + if op == Division && (r_const == 0 || l_const % r_const != 0) { + None + } else { + let result = op.function(l_const, r_const)?; + Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) + } + } + _ => None, + } + } +} diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index edb831b2158..92e15eea2fc 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -301,6 +301,7 @@ impl<'interner> Monomorphizer<'interner> { } let meta = self.interner.function_meta(&f).clone(); + let mut func_sig = meta.function_signature(); // Follow the bindings of the function signature for entry points // which are not `main` such as foldable functions. @@ -1950,6 +1951,7 @@ pub fn resolve_trait_method( TraitImplKind::Normal(impl_id) => impl_id, TraitImplKind::Assumed { object_type, trait_generics } => { let location = interner.expr_location(&expr_id); + match interner.lookup_trait_implementation( &object_type, method.trait_id,