Skip to content

Commit

Permalink
fix: "Types in a binary operation should match, but found T and T" (#…
Browse files Browse the repository at this point in the history
…4648)

# Description

## Problem\*

Resolves #4635 
Resolves #4502

## Summary\*

This was more difficult to fix than it originally seemed. The main issue
was between interactions with unbound type variables, type aliases, type
rules for operators, and operator traits.

Removing the "infer unbound type variables to be numeric" rule on
operators causes a lot of stdlib code to break where it'd be
unreasonable to have type annotations. This caused unbound type
variables to be bound to the first object type whose impl it was checked
against when calling verify trait impl.

I eventually settled on just delaying the verify trait impl check for
operators until the end of a function when more types are known.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Mar 29, 2024
1 parent 05b32fc commit 30c9f31
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 111 deletions.
7 changes: 2 additions & 5 deletions compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ pub enum TypeCheckError {
VariableMustBeMutable { name: String, span: Span },
#[error("No method named '{method_name}' found for type '{object_type}'")]
UnresolvedMethodCall { method_name: String, object_type: Type, span: Span },
#[error("Comparisons are invalid on Field types. Try casting the operands to a sized integer type first")]
InvalidComparisonOnField { span: Span },
#[error("Integers must have the same signedness LHS is {sign_x:?}, RHS is {sign_y:?}")]
IntegerSignedness { sign_x: Signedness, sign_y: Signedness, span: Span },
#[error("Integers must have the same bit width LHS is {bit_width_x}, RHS is {bit_width_y}")]
Expand All @@ -76,7 +74,7 @@ pub enum TypeCheckError {
#[error("{kind} cannot be used in a unary operation")]
InvalidUnaryOp { kind: String, span: Span },
#[error("Bitwise operations are invalid on Field types. Try casting the operands to a sized integer type first.")]
InvalidBitwiseOperationOnField { span: Span },
FieldBitwiseOp { span: Span },
#[error("Integer cannot be used with type {typ}")]
IntegerTypeMismatch { typ: Type, span: Span },
#[error("Cannot use an integer and a Field in a binary operation, try converting the Field into an integer first")]
Expand Down Expand Up @@ -224,12 +222,11 @@ impl From<TypeCheckError> for Diagnostic {
| TypeCheckError::TupleIndexOutOfBounds { span, .. }
| TypeCheckError::VariableMustBeMutable { span, .. }
| TypeCheckError::UnresolvedMethodCall { span, .. }
| TypeCheckError::InvalidComparisonOnField { span }
| TypeCheckError::IntegerSignedness { span, .. }
| TypeCheckError::IntegerBitWidth { span, .. }
| TypeCheckError::InvalidInfixOp { span, .. }
| TypeCheckError::InvalidUnaryOp { span, .. }
| TypeCheckError::InvalidBitwiseOperationOnField { span, .. }
| TypeCheckError::FieldBitwiseOp { span, .. }
| TypeCheckError::IntegerTypeMismatch { span, .. }
| TypeCheckError::FieldComparison { span, .. }
| TypeCheckError::AmbiguousBitWidth { span, .. }
Expand Down
85 changes: 46 additions & 39 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,16 @@ impl<'interner> TypeChecker<'interner> {
Ok((typ, use_impl)) => {
if use_impl {
let id = infix_expr.trait_method_id;
// Assume operators have no trait generics
self.verify_trait_constraint(
&lhs_type,
id.trait_id,
&[],
*expr_id,
span,
);

// Delay checking the trait constraint until the end of the function.
// Checking it now could bind an unbound type variable to any type
// that implements the trait.
let constraint = crate::hir_def::traits::TraitConstraint {
typ: lhs_type.clone(),
trait_id: id.trait_id,
trait_generics: Vec::new(),
};
self.trait_constraints.push((constraint, *expr_id));
self.typecheck_operator_method(*expr_id, id, &lhs_type, span);
}
typ
Expand Down Expand Up @@ -836,6 +838,10 @@ impl<'interner> TypeChecker<'interner> {
match (lhs_type, rhs_type) {
// Avoid reporting errors multiple times
(Error, _) | (_, Error) => Ok((Bool, false)),
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.comparator_operand_type_rules(&alias, other, op, span)
}

// Matches on TypeVariable must be first to follow any type
// bindings.
Expand All @@ -844,12 +850,8 @@ impl<'interner> TypeChecker<'interner> {
return self.comparator_operand_type_rules(other, binding, op, span);
}

self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((Bool, false))
}
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.comparator_operand_type_rules(&alias, other, op, span)
let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((Bool, use_impl))
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if sign_x != sign_y {
Expand Down Expand Up @@ -1079,36 +1081,43 @@ impl<'interner> TypeChecker<'interner> {
}
}

/// Handles the TypeVariable case for checking binary operators.
/// Returns true if we should use the impl for the operator instead of the primitive
/// version of it.
fn bind_type_variables_for_infix(
&mut self,
lhs_type: &Type,
op: &HirBinaryOp,
rhs_type: &Type,
span: Span,
) {
) -> bool {
self.unify(lhs_type, rhs_type, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::Binary,
span,
});

// In addition to unifying both types, we also have to bind either
// the lhs or rhs to an integer type variable. This ensures if both lhs
// and rhs are type variables, that they will have the correct integer
// type variable kind instead of TypeVariableKind::Normal.
let target = if op.kind.is_valid_for_field_type() {
Type::polymorphic_integer_or_field(self.interner)
} else {
Type::polymorphic_integer(self.interner)
};
let use_impl = !lhs_type.is_numeric();

// If this operator isn't valid for fields we have to possibly narrow
// TypeVariableKind::IntegerOrField to TypeVariableKind::Integer.
// Doing so also ensures a type error if Field is used.
// The is_numeric check is to allow impls for custom types to bypass this.
if !op.kind.is_valid_for_field_type() && lhs_type.is_numeric() {
let target = Type::polymorphic_integer(self.interner);

use BinaryOpKind::*;
use TypeCheckError::*;
self.unify(lhs_type, &target, || match op.kind {
Less | LessEqual | Greater | GreaterEqual => FieldComparison { span },
And | Or | Xor | ShiftRight | ShiftLeft => FieldBitwiseOp { span },
Modulo => FieldModulo { span },
other => unreachable!("Operator {other:?} should be valid for Field"),
});
}

self.unify(lhs_type, &target, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::Binary,
span,
});
use_impl
}

// Given a binary operator and another type. This method will produce the output type
Expand All @@ -1130,6 +1139,10 @@ impl<'interner> TypeChecker<'interner> {
match (lhs_type, rhs_type) {
// An error type on either side will always return an error
(Error, _) | (_, Error) => Ok((Error, false)),
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.infix_operand_type_rules(&alias, op, other, span)
}

// Matches on TypeVariable must be first so that we follow any type
// bindings.
Expand All @@ -1138,14 +1151,8 @@ impl<'interner> TypeChecker<'interner> {
return self.infix_operand_type_rules(binding, op, other, span);
}

self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);

// Both types are unified so the choice of which to return is arbitrary
Ok((other.clone(), false))
}
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.infix_operand_type_rules(&alias, op, other, span)
let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((other.clone(), use_impl))
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if sign_x != sign_y {
Expand All @@ -1170,7 +1177,7 @@ impl<'interner> TypeChecker<'interner> {
if op.kind == BinaryOpKind::Modulo {
return Err(TypeCheckError::FieldModulo { span });
} else {
return Err(TypeCheckError::InvalidBitwiseOperationOnField { span });
return Err(TypeCheckError::FieldBitwiseOp { span });
}
}
Ok((FieldElement, false))
Expand Down
50 changes: 24 additions & 26 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,13 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type

let function_last_type = type_checker.check_function_body(function_body_id);

// Verify any remaining trait constraints arising from the function body
for (constraint, expr_id) in std::mem::take(&mut type_checker.trait_constraints) {
let span = type_checker.interner.expr_span(&expr_id);
type_checker.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}

errors.append(&mut type_checker.errors);

// Now remove all the `where` clause constraints we added
for constraint in &expected_trait_constraints {
interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id);
}

// Check declared return type and actual return type
if !can_ignore_ret {
let (expr_span, empty_function) = function_info(interner, function_body_id);
let func_span = interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
let (expr_span, empty_function) = function_info(type_checker.interner, function_body_id);
let func_span = type_checker.interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
if let Type::TraitAsType(trait_id, _, generics) = &declared_return_type {
if interner
if type_checker
.interner
.lookup_trait_implementation(&function_last_type, *trait_id, generics)
.is_err()
{
Expand All @@ -126,7 +108,7 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
type_checker.interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
Expand All @@ -137,16 +119,32 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
};

if empty_function {
error = error.add_context(
"implicitly returns `()` as its body has no tail or `return` expression",
);
error = error.add_context("implicitly returns `()` as its body has no tail or `return` expression");
}
error
},
);
}
}

// Verify any remaining trait constraints arising from the function body
for (constraint, expr_id) in std::mem::take(&mut type_checker.trait_constraints) {
let span = type_checker.interner.expr_span(&expr_id);
type_checker.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}

// Now remove all the `where` clause constraints we added
for constraint in &expected_trait_constraints {
type_checker.interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id);
}

errors.append(&mut type_checker.errors);
errors
}

Expand Down
10 changes: 10 additions & 0 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ impl Type {
TypeBinding::Bound(binding) => binding.is_bindable(),
TypeBinding::Unbound(_) => true,
},
Type::Alias(alias, args) => alias.borrow().get_type(args).is_bindable(),
_ => false,
}
}
Expand All @@ -605,6 +606,15 @@ impl Type {
matches!(self.follow_bindings(), Type::Integer(Signedness::Unsigned, _))
}

pub fn is_numeric(&self) -> bool {
use Type::*;
use TypeVariableKind as K;
matches!(
self.follow_bindings(),
FieldElement | Integer(..) | Bool | TypeVariable(_, K::Integer | K::IntegerOrField)
)
}

fn contains_numeric_typevar(&self, target_id: TypeVariableId) -> bool {
// True if the given type is a NamedGeneric with the target_id
let named_generic_id_matches_target = |typ: &Type| {
Expand Down
12 changes: 6 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1033,19 +1033,19 @@ mod test {
fn resolve_complex_closures() {
let src = r#"
fn main(x: Field) -> pub Field {
let closure_without_captures = |x| x + x;
let closure_without_captures = |x: Field| -> Field { x + x };
let a = closure_without_captures(1);
let closure_capturing_a_param = |y| y + x;
let closure_capturing_a_param = |y: Field| -> Field { y + x };
let b = closure_capturing_a_param(2);
let closure_capturing_a_local_var = |y| y + b;
let closure_capturing_a_local_var = |y: Field| -> Field { y + b };
let c = closure_capturing_a_local_var(3);
let closure_with_transitive_captures = |y| {
let closure_with_transitive_captures = |y: Field| -> Field {
let d = 5;
let nested_closure = |z| {
let doubly_nested_closure = |w| w + x + b;
let nested_closure = |z: Field| -> Field {
let doubly_nested_closure = |w: Field| -> Field { w + x + b };
a + z + y + d + x + doubly_nested_closure(4) + x + y
};
let res = nested_closure(5);
Expand Down
14 changes: 7 additions & 7 deletions noir_stdlib/src/cmp.nr
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ trait Eq {

impl Eq for Field { fn eq(self, other: Field) -> bool { self == other } }

impl Eq for u1 { fn eq(self, other: u1) -> bool { self == other } }
impl Eq for u8 { fn eq(self, other: u8) -> bool { self == other } }
impl Eq for u32 { fn eq(self, other: u32) -> bool { self == other } }
impl Eq for u64 { fn eq(self, other: u64) -> bool { self == other } }
impl Eq for u32 { fn eq(self, other: u32) -> bool { self == other } }
impl Eq for u8 { fn eq(self, other: u8) -> bool { self == other } }
impl Eq for u1 { fn eq(self, other: u1) -> bool { self == other } }

impl Eq for i8 { fn eq(self, other: i8) -> bool { self == other } }
impl Eq for i32 { fn eq(self, other: i32) -> bool { self == other } }
Expand Down Expand Up @@ -107,8 +107,8 @@ trait Ord {

// Note: Field deliberately does not implement Ord

impl Ord for u8 {
fn cmp(self, other: u8) -> Ordering {
impl Ord for u64 {
fn cmp(self, other: u64) -> Ordering {
if self < other {
Ordering::less()
} else if self > other {
Expand All @@ -131,8 +131,8 @@ impl Ord for u32 {
}
}

impl Ord for u64 {
fn cmp(self, other: u64) -> Ordering {
impl Ord for u8 {
fn cmp(self, other: u8) -> Ordering {
if self < other {
Ordering::less()
} else if self > other {
Expand Down
Loading

0 comments on commit 30c9f31

Please sign in to comment.