diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 0da39edfd85..94b5841e52c 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -565,7 +565,7 @@ impl ForRange { identifier: Ident, block: Expression, for_loop_span: Span, - ) -> StatementKind { + ) -> Statement { /// Counter used to generate unique names when desugaring /// code in the parser requires the creation of fresh variables. /// The parser is stateless so this is a static global instead. @@ -662,7 +662,8 @@ impl ForRange { let block = ExpressionKind::Block(BlockExpression { statements: vec![let_array, for_loop], }); - StatementKind::Expression(Expression::new(block, for_loop_span)) + let kind = StatementKind::Expression(Expression::new(block, for_loop_span)); + Statement { kind, span: for_loop_span } } } } diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs new file mode 100644 index 00000000000..19ee67b442c --- /dev/null +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -0,0 +1,604 @@ +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use regex::Regex; +use rustc_hash::FxHashSet as HashSet; + +use crate::{ + ast::{ + ArrayLiteral, ConstructorExpression, IfExpression, InfixExpression, Lambda, + UnresolvedTypeExpression, + }, + hir::{ + resolution::{errors::ResolverError, resolver::LambdaContext}, + type_check::TypeCheckError, + }, + hir_def::{ + expr::{ + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, + HirConstructorExpression, HirIdent, HirIfExpression, HirIndexExpression, + HirInfixExpression, HirLambda, HirMemberAccess, HirMethodCallExpression, + HirMethodReference, HirPrefixExpression, + }, + traits::TraitConstraint, + }, + macros_api::{ + BlockExpression, CallExpression, CastExpression, Expression, ExpressionKind, HirExpression, + HirLiteral, HirStatement, Ident, IndexExpression, Literal, MemberAccessExpression, + MethodCallExpression, PrefixExpression, + }, + node_interner::{DefinitionKind, ExprId, FuncId}, + Shared, StructType, Type, +}; + +use super::Elaborator; + +impl Elaborator { + pub(super) fn elaborate_expression(&mut self, expr: Expression) -> (ExprId, Type) { + let (hir_expr, typ) = match expr.kind { + ExpressionKind::Literal(literal) => self.elaborate_literal(literal, expr.span), + ExpressionKind::Block(block) => self.elaborate_block(block), + ExpressionKind::Prefix(prefix) => self.elaborate_prefix(*prefix), + ExpressionKind::Index(index) => self.elaborate_index(*index), + ExpressionKind::Call(call) => self.elaborate_call(*call, expr.span), + ExpressionKind::MethodCall(call) => self.elaborate_method_call(*call, expr.span), + ExpressionKind::Constructor(constructor) => self.elaborate_constructor(*constructor), + ExpressionKind::MemberAccess(access) => { + return self.elaborate_member_access(*access, expr.span) + } + ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span), + ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span), + ExpressionKind::If(if_) => self.elaborate_if(*if_), + ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), + ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple), + ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda), + ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr), + ExpressionKind::Quote(quote) => self.elaborate_quote(quote), + ExpressionKind::Comptime(comptime) => self.elaborate_comptime_block(comptime), + ExpressionKind::Error => (HirExpression::Error, Type::Error), + }; + let id = self.interner.push_expr(hir_expr); + self.interner.push_expr_location(id, expr.span, self.file); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } + + pub(super) fn elaborate_block(&mut self, block: BlockExpression) -> (HirExpression, Type) { + self.push_scope(); + let mut block_type = Type::Unit; + let mut statements = Vec::with_capacity(block.statements.len()); + + for (i, statement) in block.statements.into_iter().enumerate() { + let (id, stmt_type) = self.elaborate_statement(statement); + statements.push(id); + + if let HirStatement::Semi(expr) = self.interner.statement(&id) { + let inner_expr_type = self.interner.id_type(expr); + let span = self.interner.expr_span(&expr); + + self.unify(&inner_expr_type, &Type::Unit, || TypeCheckError::UnusedResultError { + expr_type: inner_expr_type.clone(), + expr_span: span, + }); + + if i + 1 == statements.len() { + block_type = stmt_type; + } + } + } + + self.pop_scope(); + (HirExpression::Block(HirBlockExpression { statements }), block_type) + } + + fn elaborate_literal(&mut self, literal: Literal, span: Span) -> (HirExpression, Type) { + use HirExpression::Literal as Lit; + match literal { + Literal::Unit => (Lit(HirLiteral::Unit), Type::Unit), + Literal::Bool(b) => (Lit(HirLiteral::Bool(b)), Type::Bool), + Literal::Integer(integer, sign) => { + let int = HirLiteral::Integer(integer, sign); + (Lit(int), self.polymorphic_integer_or_field()) + } + Literal::Str(str) | Literal::RawStr(str, _) => { + let len = Type::Constant(str.len() as u64); + (Lit(HirLiteral::Str(str)), Type::String(Box::new(len))) + } + Literal::FmtStr(str) => self.elaborate_fmt_string(str, span), + Literal::Array(array_literal) => { + self.elaborate_array_literal(array_literal, span, true) + } + Literal::Slice(array_literal) => { + self.elaborate_array_literal(array_literal, span, false) + } + } + } + + fn elaborate_array_literal( + &mut self, + array_literal: ArrayLiteral, + span: Span, + is_array: bool, + ) -> (HirExpression, Type) { + let (expr, elem_type, length) = match array_literal { + ArrayLiteral::Standard(elements) => { + let first_elem_type = self.interner.next_type_variable(); + let first_span = elements.first().map(|elem| elem.span).unwrap_or(span); + + let elements = vecmap(elements.into_iter().enumerate(), |(i, elem)| { + let span = elem.span; + let (elem_id, elem_type) = self.elaborate_expression(elem); + + self.unify(&elem_type, &first_elem_type, || { + TypeCheckError::NonHomogeneousArray { + first_span, + first_type: first_elem_type.to_string(), + first_index: 0, + second_span: span, + second_type: elem_type.to_string(), + second_index: i, + } + .add_context("elements in an array must have the same type") + }); + elem_id + }); + + let length = Type::Constant(elements.len() as u64); + (HirArrayLiteral::Standard(elements), first_elem_type, length) + } + ArrayLiteral::Repeated { repeated_element, length } => { + let span = length.span; + let length = + UnresolvedTypeExpression::from_expr(*length, span).unwrap_or_else(|error| { + self.push_err(ResolverError::ParserError(Box::new(error))); + UnresolvedTypeExpression::Constant(0, span) + }); + + let length = self.convert_expression_type(length); + let (repeated_element, elem_type) = self.elaborate_expression(*repeated_element); + + let length_clone = length.clone(); + (HirArrayLiteral::Repeated { repeated_element, length }, elem_type, length_clone) + } + }; + let constructor = if is_array { HirLiteral::Array } else { HirLiteral::Slice }; + let elem_type = Box::new(elem_type); + let typ = if is_array { + Type::Array(Box::new(length), elem_type) + } else { + Type::Slice(elem_type) + }; + (HirExpression::Literal(constructor(expr)), typ) + } + + fn elaborate_fmt_string(&mut self, str: String, call_expr_span: Span) -> (HirExpression, Type) { + let re = Regex::new(r"\{([a-zA-Z0-9_]+)\}") + .expect("ICE: an invalid regex pattern was used for checking format strings"); + + let mut fmt_str_idents = Vec::new(); + let mut capture_types = Vec::new(); + + for field in re.find_iter(&str) { + let matched_str = field.as_str(); + let ident_name = &matched_str[1..(matched_str.len() - 1)]; + + let scope_tree = self.scopes.current_scope_tree(); + let variable = scope_tree.find(ident_name); + if let Some((old_value, _)) = variable { + old_value.num_times_used += 1; + let ident = HirExpression::Ident(old_value.ident.clone()); + let expr_id = self.interner.push_expr(ident); + self.interner.push_expr_location(expr_id, call_expr_span, self.file); + let ident = old_value.ident.clone(); + let typ = self.type_check_variable(ident, expr_id); + self.interner.push_expr_type(expr_id, typ.clone()); + capture_types.push(typ); + fmt_str_idents.push(expr_id); + } else if ident_name.parse::().is_ok() { + self.push_err(ResolverError::NumericConstantInFormatString { + name: ident_name.to_owned(), + span: call_expr_span, + }); + } else { + self.push_err(ResolverError::VariableNotDeclared { + name: ident_name.to_owned(), + span: call_expr_span, + }); + } + } + + let len = Type::Constant(str.len() as u64); + let typ = Type::FmtString(Box::new(len), Box::new(Type::Tuple(capture_types))); + (HirExpression::Literal(HirLiteral::FmtStr(str, fmt_str_idents)), typ) + } + + fn elaborate_prefix(&mut self, prefix: PrefixExpression) -> (HirExpression, Type) { + let span = prefix.rhs.span; + let (rhs, rhs_type) = self.elaborate_expression(prefix.rhs); + let ret_type = self.type_check_prefix_operand(&prefix.operator, &rhs_type, span); + (HirExpression::Prefix(HirPrefixExpression { operator: prefix.operator, rhs }), ret_type) + } + + fn elaborate_index(&mut self, index_expr: IndexExpression) -> (HirExpression, Type) { + let span = index_expr.index.span; + let (index, index_type) = self.elaborate_expression(index_expr.index); + + let expected = self.polymorphic_integer_or_field(); + self.unify(&index_type, &expected, || TypeCheckError::TypeMismatch { + expected_typ: "an integer".to_owned(), + expr_typ: index_type.to_string(), + expr_span: span, + }); + + // When writing `a[i]`, if `a : &mut ...` then automatically dereference `a` as many + // times as needed to get the underlying array. + let lhs_span = index_expr.collection.span; + let (lhs, lhs_type) = self.elaborate_expression(index_expr.collection); + let (collection, lhs_type) = self.insert_auto_dereferences(lhs, lhs_type); + + let typ = match lhs_type.follow_bindings() { + // XXX: We can check the array bounds here also, but it may be better to constant fold first + // and have ConstId instead of ExprId for constants + Type::Array(_, base_type) => *base_type, + Type::Slice(base_type) => *base_type, + Type::Error => Type::Error, + typ => { + self.push_err(TypeCheckError::TypeMismatch { + expected_typ: "Array".to_owned(), + expr_typ: typ.to_string(), + expr_span: lhs_span, + }); + Type::Error + } + }; + + let expr = HirExpression::Index(HirIndexExpression { collection, index }); + (expr, typ) + } + + fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) { + let (func, func_type) = self.elaborate_expression(*call.func); + + let mut arguments = Vec::with_capacity(call.arguments.len()); + let args = vecmap(call.arguments, |arg| { + let span = arg.span; + let (arg, typ) = self.elaborate_expression(arg); + arguments.push(arg); + (typ, arg, span) + }); + + let location = Location::new(span, self.file); + let call = HirCallExpression { func, arguments, location }; + let typ = self.type_check_call(&call, func_type, args, span); + (HirExpression::Call(call), typ) + } + + fn elaborate_method_call( + &mut self, + method_call: MethodCallExpression, + span: Span, + ) -> (HirExpression, Type) { + let object_span = method_call.object.span; + let (mut object, mut object_type) = self.elaborate_expression(method_call.object); + object_type = object_type.follow_bindings(); + + let method_name = method_call.method_name.0.contents.as_str(); + match self.lookup_method(&object_type, method_name, span) { + Some(method_ref) => { + // Automatically add `&mut` if the method expects a mutable reference and + // the object is not already one. + if let HirMethodReference::FuncId(func_id) = &method_ref { + if *func_id != FuncId::dummy_id() { + let function_type = self.interner.function_meta(func_id).typ.clone(); + + self.try_add_mutable_reference_to_object( + &function_type, + &mut object_type, + &mut object, + ); + } + } + + // These arguments will be given to the desugared function call. + // Compared to the method arguments, they also contain the object. + let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1); + let mut arguments = Vec::with_capacity(method_call.arguments.len()); + + function_args.push((object_type.clone(), object, object_span)); + + for arg in method_call.arguments { + let span = arg.span; + let (arg, typ) = self.elaborate_expression(arg); + arguments.push(arg); + function_args.push((typ, arg, span)); + } + + let location = Location::new(span, self.file); + let method = method_call.method_name; + let method_call = HirMethodCallExpression { method, object, arguments, location }; + + // Desugar the method call into a normal, resolved function call + // so that the backend doesn't need to worry about methods + // TODO: update object_type here? + let ((function_id, function_name), function_call) = method_call.into_function_call( + &method_ref, + object_type, + location, + &mut self.interner, + ); + + let func_type = self.type_check_variable(function_name, function_id); + + // Type check the new call now that it has been changed from a method call + // to a function call. This way we avoid duplicating code. + let typ = self.type_check_call(&function_call, func_type, function_args, span); + (HirExpression::Call(function_call), typ) + } + None => (HirExpression::Error, Type::Error), + } + } + + fn elaborate_constructor( + &mut self, + constructor: ConstructorExpression, + ) -> (HirExpression, Type) { + let span = constructor.type_name.span(); + + match self.lookup_type_or_error(constructor.type_name) { + Some(Type::Struct(r#type, struct_generics)) => { + let struct_type = r#type.clone(); + let generics = struct_generics.clone(); + + let fields = constructor.fields; + let field_types = r#type.borrow().get_fields(&struct_generics); + let fields = self.resolve_constructor_expr_fields( + struct_type.clone(), + field_types, + fields, + span, + ); + let expr = HirExpression::Constructor(HirConstructorExpression { + fields, + r#type, + struct_generics, + }); + (expr, Type::Struct(struct_type, generics)) + } + Some(typ) => { + self.push_err(ResolverError::NonStructUsedInConstructor { typ, span }); + (HirExpression::Error, Type::Error) + } + None => (HirExpression::Error, Type::Error), + } + } + + /// Resolve all the fields of a struct constructor expression. + /// Ensures all fields are present, none are repeated, and all + /// are part of the struct. + fn resolve_constructor_expr_fields( + &mut self, + struct_type: Shared, + field_types: Vec<(String, Type)>, + fields: Vec<(Ident, Expression)>, + span: Span, + ) -> Vec<(Ident, ExprId)> { + let mut ret = Vec::with_capacity(fields.len()); + let mut seen_fields = HashSet::default(); + let mut unseen_fields = struct_type.borrow().field_names(); + + for (field_name, field) in fields { + let expected_type = field_types.iter().find(|(name, _)| name == &field_name.0.contents); + let expected_type = expected_type.map(|(_, typ)| typ).unwrap_or(&Type::Error); + + let field_span = field.span; + let (resolved, field_type) = self.elaborate_expression(field); + + if unseen_fields.contains(&field_name) { + unseen_fields.remove(&field_name); + seen_fields.insert(field_name.clone()); + + self.unify_with_coercions(&field_type, expected_type, resolved, || { + TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: field_type.to_string(), + expr_span: field_span, + } + }); + } else if seen_fields.contains(&field_name) { + // duplicate field + self.push_err(ResolverError::DuplicateField { field: field_name.clone() }); + } else { + // field not required by struct + self.push_err(ResolverError::NoSuchField { + field: field_name.clone(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret.push((field_name, resolved)); + } + + if !unseen_fields.is_empty() { + self.push_err(ResolverError::MissingFields { + span, + missing_fields: unseen_fields.into_iter().map(|field| field.to_string()).collect(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret + } + + fn elaborate_member_access( + &mut self, + access: MemberAccessExpression, + span: Span, + ) -> (ExprId, Type) { + let (lhs, lhs_type) = self.elaborate_expression(access.lhs); + let rhs = access.rhs; + // `is_offset` is only used when lhs is a reference and we want to return a reference to rhs + let access = HirMemberAccess { lhs, rhs, is_offset: false }; + let expr_id = self.intern_expr(HirExpression::MemberAccess(access.clone()), span); + let typ = self.type_check_member_access(access, expr_id, lhs_type, span); + self.interner.push_expr_type(expr_id, typ.clone()); + (expr_id, typ) + } + + fn intern_expr(&mut self, expr: HirExpression, span: Span) -> ExprId { + let id = self.interner.push_expr(expr); + self.interner.push_expr_location(id, span, self.file); + id + } + + fn elaborate_cast(&mut self, cast: CastExpression, span: Span) -> (HirExpression, Type) { + let (lhs, lhs_type) = self.elaborate_expression(cast.lhs); + let r#type = self.resolve_type(cast.r#type); + let result = self.check_cast(lhs_type, &r#type, span); + let expr = HirExpression::Cast(HirCastExpression { lhs, r#type }); + (expr, result) + } + + fn elaborate_infix(&mut self, infix: InfixExpression, span: Span) -> (ExprId, Type) { + let (lhs, lhs_type) = self.elaborate_expression(infix.lhs); + let (rhs, rhs_type) = self.elaborate_expression(infix.rhs); + let trait_id = self.interner.get_operator_trait_method(infix.operator.contents); + + let operator = HirBinaryOp::new(infix.operator, self.file); + let expr = HirExpression::Infix(HirInfixExpression { + lhs, + operator, + trait_method_id: trait_id, + rhs, + }); + + let expr_id = self.interner.push_expr(expr); + self.interner.push_expr_location(expr_id, span, self.file); + + let typ = match self.infix_operand_type_rules(&lhs_type, &operator, &rhs_type, span) { + Ok((typ, use_impl)) => { + if use_impl { + // Delay checking the trait constraint until the end of the function. + // Checking it now could bind an unbound type variable to any type + // that implements the trait. + let constraint = TraitConstraint { + typ: lhs_type.clone(), + trait_id: trait_id.trait_id, + trait_generics: Vec::new(), + }; + self.trait_constraints.push((constraint, expr_id)); + self.type_check_operator_method(expr_id, trait_id, &lhs_type, span); + } + typ + } + Err(error) => { + self.push_err(error); + Type::Error + } + }; + + self.interner.push_expr_type(expr_id, typ.clone()); + (expr_id, typ) + } + + fn elaborate_if(&mut self, if_expr: IfExpression) -> (HirExpression, Type) { + let expr_span = if_expr.condition.span; + let (condition, cond_type) = self.elaborate_expression(if_expr.condition); + let (consequence, mut ret_type) = self.elaborate_expression(if_expr.consequence); + + self.unify(&cond_type, &Type::Bool, || TypeCheckError::TypeMismatch { + expected_typ: Type::Bool.to_string(), + expr_typ: cond_type.to_string(), + expr_span, + }); + + let alternative = if_expr.alternative.map(|alternative| { + let expr_span = alternative.span; + let (else_, else_type) = self.elaborate_expression(alternative); + + self.unify(&ret_type, &else_type, || { + let err = TypeCheckError::TypeMismatch { + expected_typ: ret_type.to_string(), + expr_typ: else_type.to_string(), + expr_span, + }; + + let context = if ret_type == Type::Unit { + "Are you missing a semicolon at the end of your 'else' branch?" + } else if else_type == Type::Unit { + "Are you missing a semicolon at the end of the first block of this 'if'?" + } else { + "Expected the types of both if branches to be equal" + }; + + err.add_context(context) + }); + else_ + }); + + if alternative.is_none() { + ret_type = Type::Unit; + } + + let if_expr = HirIfExpression { condition, consequence, alternative }; + (HirExpression::If(if_expr), ret_type) + } + + fn elaborate_tuple(&mut self, tuple: Vec) -> (HirExpression, Type) { + let mut element_ids = Vec::with_capacity(tuple.len()); + let mut element_types = Vec::with_capacity(tuple.len()); + + for element in tuple { + let (id, typ) = self.elaborate_expression(element); + element_ids.push(id); + element_types.push(typ); + } + + (HirExpression::Tuple(element_ids), Type::Tuple(element_types)) + } + + fn elaborate_lambda(&mut self, lambda: Lambda) -> (HirExpression, Type) { + self.push_scope(); + let scope_index = self.scopes.current_scope_index(); + + self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); + + let mut arg_types = Vec::with_capacity(lambda.parameters.len()); + let parameters = vecmap(lambda.parameters, |(pattern, typ)| { + let parameter = DefinitionKind::Local(None); + let typ = self.resolve_inferred_type(typ); + arg_types.push(typ.clone()); + (self.elaborate_pattern(pattern, typ.clone(), parameter), typ) + }); + + let return_type = self.resolve_inferred_type(lambda.return_type); + let body_span = lambda.body.span; + let (body, body_type) = self.elaborate_expression(lambda.body); + + let lambda_context = self.lambda_stack.pop().unwrap(); + self.pop_scope(); + + self.unify(&body_type, &return_type, || TypeCheckError::TypeMismatch { + expected_typ: return_type.to_string(), + expr_typ: body_type.to_string(), + expr_span: body_span, + }); + + let captured_vars = vecmap(&lambda_context.captures, |capture| { + self.interner.definition_type(capture.ident.id) + }); + + let env_type = + if captured_vars.is_empty() { Type::Unit } else { Type::Tuple(captured_vars) }; + + let captures = lambda_context.captures; + let expr = HirExpression::Lambda(HirLambda { parameters, return_type, body, captures }); + (expr, Type::Function(arg_types, Box::new(body_type), Box::new(env_type))) + } + + fn elaborate_quote(&mut self, block: BlockExpression) -> (HirExpression, Type) { + (HirExpression::Quote(block), Type::Code) + } + + fn elaborate_comptime_block(&mut self, _comptime: BlockExpression) -> (HirExpression, Type) { + todo!("Elaborate comptime block") + } +} diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs new file mode 100644 index 00000000000..9a34fa847f5 --- /dev/null +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -0,0 +1,159 @@ +#![allow(unused)] +use std::{ + collections::{BTreeMap, BTreeSet}, + rc::Rc, +}; + +use crate::graph::CrateId; +use crate::hir::def_map::CrateDefMap; +use crate::{ + ast::{ + ArrayLiteral, ConstructorExpression, FunctionKind, IfExpression, InfixExpression, Lambda, + UnresolvedTraitConstraint, UnresolvedTypeExpression, + }, + hir::{ + def_collector::dc_crate::CompilationError, + resolution::{errors::ResolverError, path_resolver::PathResolver, resolver::LambdaContext}, + scope::ScopeForest as GenericScopeForest, + type_check::TypeCheckError, + }, + hir_def::{ + expr::{ + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, + HirConstructorExpression, HirIdent, HirIfExpression, HirIndexExpression, + HirInfixExpression, HirLambda, HirMemberAccess, HirMethodCallExpression, + HirMethodReference, HirPrefixExpression, + }, + traits::TraitConstraint, + }, + macros_api::{ + BlockExpression, CallExpression, CastExpression, Expression, ExpressionKind, HirExpression, + HirLiteral, HirStatement, Ident, IndexExpression, Literal, MemberAccessExpression, + MethodCallExpression, NodeInterner, NoirFunction, PrefixExpression, Statement, + StatementKind, StructId, + }, + node_interner::{DefinitionKind, DependencyId, ExprId, FuncId, StmtId, TraitId}, + Shared, StructType, Type, TypeVariable, +}; + +mod expressions; +mod patterns; +mod scope; +mod statements; +mod types; + +use fm::FileId; +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use regex::Regex; +use rustc_hash::FxHashSet as HashSet; + +/// ResolverMetas are tagged onto each definition to track how many times they are used +#[derive(Debug, PartialEq, Eq)] +struct ResolverMeta { + num_times_used: usize, + ident: HirIdent, + warn_if_unused: bool, +} + +type ScopeForest = GenericScopeForest; + +struct Elaborator { + scopes: ScopeForest, + + errors: Vec, + + interner: NodeInterner, + file: FileId, + + in_unconstrained_fn: bool, + nested_loops: usize, + + /// True if the current module is a contract. + /// This is usually determined by self.path_resolver.module_id(), but it can + /// be overridden for impls. Impls are an odd case since the methods within resolve + /// as if they're in the parent module, but should be placed in a child module. + /// Since they should be within a child module, in_contract is manually set to false + /// for these so we can still resolve them in the parent module without them being in a contract. + in_contract: bool, + + /// Contains a mapping of the current struct or functions's generics to + /// unique type variables if we're resolving a struct. Empty otherwise. + /// This is a Vec rather than a map to preserve the order a functions generics + /// were declared in. + generics: Vec<(Rc, TypeVariable, Span)>, + + /// When resolving lambda expressions, we need to keep track of the variables + /// that are captured. We do this in order to create the hidden environment + /// parameter for the lambda function. + lambda_stack: Vec, + + /// Set to the current type if we're resolving an impl + self_type: Option, + + /// The current dependency item we're resolving. + /// Used to link items to their dependencies in the dependency graph + current_item: Option, + + trait_id: Option, + + path_resolver: Rc, + def_maps: BTreeMap, + + /// In-resolution names + /// + /// This needs to be a set because we can have multiple in-resolution + /// names when resolving structs that are declared in reverse order of their + /// dependencies, such as in the following case: + /// + /// ``` + /// struct Wrapper { + /// value: Wrapped + /// } + /// struct Wrapped { + /// } + /// ``` + resolving_ids: BTreeSet, + + trait_bounds: Vec, + + current_function: Option, + + /// All type variables created in the current function. + /// This map is used to default any integer type variables at the end of + /// a function (before checking trait constraints) if a type wasn't already chosen. + type_variables: Vec, + + /// Trait constraints are collected during type checking until they are + /// verified at the end of a function. This is because constraints arise + /// on each variable, but it is only until function calls when the types + /// needed for the trait constraint may become known. + trait_constraints: Vec<(TraitConstraint, ExprId)>, +} + +impl Elaborator { + fn elaborate_function(&mut self, function: NoirFunction, _id: FuncId) { + // This is a stub until the elaborator is connected to dc_crate + match function.kind { + FunctionKind::LowLevel => todo!(), + FunctionKind::Builtin => todo!(), + FunctionKind::Oracle => todo!(), + FunctionKind::Recursive => todo!(), + FunctionKind::Normal => { + let _body = self.elaborate_block(function.def.body); + } + } + } + + fn push_scope(&mut self) { + // stub + } + + fn pop_scope(&mut self) { + // stub + } + + fn push_err(&mut self, error: impl Into) { + self.errors.push(error.into()); + } +} diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs new file mode 100644 index 00000000000..b51115417e7 --- /dev/null +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -0,0 +1,463 @@ +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use rustc_hash::FxHashSet as HashSet; + +use crate::{ + ast::ERROR_IDENT, + hir::{ + resolution::errors::ResolverError, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::{HirIdent, ImplKind}, + stmt::HirPattern, + }, + macros_api::{HirExpression, Ident, Path, Pattern}, + node_interner::{DefinitionId, DefinitionKind, ExprId, TraitImplKind}, + Shared, StructType, Type, TypeBindings, +}; + +use super::{Elaborator, ResolverMeta}; + +impl Elaborator { + pub(super) fn elaborate_pattern( + &mut self, + pattern: Pattern, + expected_type: Type, + definition_kind: DefinitionKind, + ) -> HirPattern { + self.elaborate_pattern_mut(pattern, expected_type, definition_kind, None) + } + + fn elaborate_pattern_mut( + &mut self, + pattern: Pattern, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> HirPattern { + match pattern { + Pattern::Identifier(name) => { + // If this definition is mutable, do not store the rhs because it will + // not always refer to the correct value of the variable + let definition = match (mutable, definition) { + (Some(_), DefinitionKind::Local(_)) => DefinitionKind::Local(None), + (_, other) => other, + }; + let ident = self.add_variable_decl(name, mutable.is_some(), true, definition); + self.interner.push_definition_type(ident.id, expected_type); + HirPattern::Identifier(ident) + } + Pattern::Mutable(pattern, span, _) => { + if let Some(first_mut) = mutable { + self.push_err(ResolverError::UnnecessaryMut { first_mut, second_mut: span }); + } + + let pattern = + self.elaborate_pattern_mut(*pattern, expected_type, definition, Some(span)); + let location = Location::new(span, self.file); + HirPattern::Mutable(Box::new(pattern), location) + } + Pattern::Tuple(fields, span) => { + let field_types = match expected_type { + Type::Tuple(fields) => fields, + Type::Error => Vec::new(), + expected_type => { + let tuple = + Type::Tuple(vecmap(&fields, |_| self.interner.next_type_variable())); + + self.push_err(TypeCheckError::TypeMismatchWithSource { + expected: expected_type, + actual: tuple, + span, + source: Source::Assignment, + }); + Vec::new() + } + }; + + let fields = vecmap(fields.into_iter().enumerate(), |(i, field)| { + let field_type = field_types.get(i).cloned().unwrap_or(Type::Error); + self.elaborate_pattern_mut(field, field_type, definition.clone(), mutable) + }); + let location = Location::new(span, self.file); + HirPattern::Tuple(fields, location) + } + Pattern::Struct(name, fields, span) => self.elaborate_struct_pattern( + name, + fields, + span, + expected_type, + definition, + mutable, + ), + } + } + + fn elaborate_struct_pattern( + &mut self, + name: Path, + fields: Vec<(Ident, Pattern)>, + span: Span, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> HirPattern { + let error_identifier = |this: &mut Self| { + // Must create a name here to return a HirPattern::Identifier. Allowing + // shadowing here lets us avoid further errors if we define ERROR_IDENT + // multiple times. + let name = ERROR_IDENT.into(); + let identifier = this.add_variable_decl(name, false, true, definition.clone()); + HirPattern::Identifier(identifier) + }; + + let (struct_type, generics) = match self.lookup_type_or_error(name) { + Some(Type::Struct(struct_type, generics)) => (struct_type, generics), + None => return error_identifier(self), + Some(typ) => { + self.push_err(ResolverError::NonStructUsedInConstructor { typ, span }); + return error_identifier(self); + } + }; + + let actual_type = Type::Struct(struct_type.clone(), generics); + let location = Location::new(span, self.file); + + self.unify(&actual_type, &expected_type, || TypeCheckError::TypeMismatchWithSource { + expected: expected_type.clone(), + actual: actual_type.clone(), + span: location.span, + source: Source::Assignment, + }); + + let typ = struct_type.clone(); + let fields = self.resolve_constructor_pattern_fields( + typ, + fields, + span, + expected_type.clone(), + definition, + mutable, + ); + + HirPattern::Struct(expected_type, fields, location) + } + + /// Resolve all the fields of a struct constructor expression. + /// Ensures all fields are present, none are repeated, and all + /// are part of the struct. + fn resolve_constructor_pattern_fields( + &mut self, + struct_type: Shared, + fields: Vec<(Ident, Pattern)>, + span: Span, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> Vec<(Ident, HirPattern)> { + let mut ret = Vec::with_capacity(fields.len()); + let mut seen_fields = HashSet::default(); + let mut unseen_fields = struct_type.borrow().field_names(); + + for (field, pattern) in fields { + let field_type = expected_type.get_field_type(&field.0.contents).unwrap_or(Type::Error); + let resolved = + self.elaborate_pattern_mut(pattern, field_type, definition.clone(), mutable); + + if unseen_fields.contains(&field) { + unseen_fields.remove(&field); + seen_fields.insert(field.clone()); + } else if seen_fields.contains(&field) { + // duplicate field + self.push_err(ResolverError::DuplicateField { field: field.clone() }); + } else { + // field not required by struct + self.push_err(ResolverError::NoSuchField { + field: field.clone(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret.push((field, resolved)); + } + + if !unseen_fields.is_empty() { + self.push_err(ResolverError::MissingFields { + span, + missing_fields: unseen_fields.into_iter().map(|field| field.to_string()).collect(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret + } + + pub(super) fn add_variable_decl( + &mut self, + name: Ident, + mutable: bool, + allow_shadowing: bool, + definition: DefinitionKind, + ) -> HirIdent { + self.add_variable_decl_inner(name, mutable, allow_shadowing, true, definition) + } + + fn add_variable_decl_inner( + &mut self, + name: Ident, + mutable: bool, + allow_shadowing: bool, + warn_if_unused: bool, + definition: DefinitionKind, + ) -> HirIdent { + if definition.is_global() { + return self.add_global_variable_decl(name, definition); + } + + let location = Location::new(name.span(), self.file); + let id = + self.interner.push_definition(name.0.contents.clone(), mutable, definition, location); + let ident = HirIdent::non_trait_method(id, location); + let resolver_meta = + ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused }; + + let scope = self.scopes.get_mut_scope(); + let old_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); + + if !allow_shadowing { + if let Some(old_value) = old_value { + self.push_err(ResolverError::DuplicateDefinition { + name: name.0.contents, + first_span: old_value.ident.location.span, + second_span: location.span, + }); + } + } + + ident + } + + fn add_global_variable_decl(&mut self, name: Ident, definition: DefinitionKind) -> HirIdent { + let scope = self.scopes.get_mut_scope(); + + // This check is necessary to maintain the same definition ids in the interner. Currently, each function uses a new resolver that has its own ScopeForest and thus global scope. + // We must first check whether an existing definition ID has been inserted as otherwise there will be multiple definitions for the same global statement. + // This leads to an error in evaluation where the wrong definition ID is selected when evaluating a statement using the global. The check below prevents this error. + let mut global_id = None; + let global = self.interner.get_all_globals(); + for global_info in global { + if global_info.ident == name + && global_info.local_id == self.path_resolver.local_module_id() + { + global_id = Some(global_info.id); + } + } + + let (ident, resolver_meta) = if let Some(id) = global_id { + let global = self.interner.get_global(id); + let hir_ident = HirIdent::non_trait_method(global.definition_id, global.location); + let ident = hir_ident.clone(); + let resolver_meta = ResolverMeta { num_times_used: 0, ident, warn_if_unused: true }; + (hir_ident, resolver_meta) + } else { + let location = Location::new(name.span(), self.file); + let id = + self.interner.push_definition(name.0.contents.clone(), false, definition, location); + let ident = HirIdent::non_trait_method(id, location); + let resolver_meta = + ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused: true }; + (ident, resolver_meta) + }; + + let old_global_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); + if let Some(old_global_value) = old_global_value { + self.push_err(ResolverError::DuplicateDefinition { + name: name.0.contents.clone(), + first_span: old_global_value.ident.location.span, + second_span: name.span(), + }); + } + ident + } + + // Checks for a variable having been declared before. + // (Variable declaration and definition cannot be separate in Noir.) + // Once the variable has been found, intern and link `name` to this definition, + // returning (the ident, the IdentId of `name`) + // + // If a variable is not found, then an error is logged and a dummy id + // is returned, for better error reporting UX + pub(super) fn find_variable_or_default(&mut self, name: &Ident) -> (HirIdent, usize) { + self.use_variable(name).unwrap_or_else(|error| { + self.push_err(error); + let id = DefinitionId::dummy_id(); + let location = Location::new(name.span(), self.file); + (HirIdent::non_trait_method(id, location), 0) + }) + } + + /// Lookup and use the specified variable. + /// This will increment its use counter by one and return the variable if found. + /// If the variable is not found, an error is returned. + pub(super) fn use_variable( + &mut self, + name: &Ident, + ) -> Result<(HirIdent, usize), ResolverError> { + // Find the definition for this Ident + let scope_tree = self.scopes.current_scope_tree(); + let variable = scope_tree.find(&name.0.contents); + + let location = Location::new(name.span(), self.file); + if let Some((variable_found, scope)) = variable { + variable_found.num_times_used += 1; + let id = variable_found.ident.id; + Ok((HirIdent::non_trait_method(id, location), scope)) + } else { + Err(ResolverError::VariableNotDeclared { + name: name.0.contents.clone(), + span: name.0.span(), + }) + } + } + + pub(super) fn elaborate_variable(&mut self, variable: Path) -> (ExprId, Type) { + let span = variable.span; + let expr = self.resolve_variable(variable); + let id = self.interner.push_expr(HirExpression::Ident(expr.clone())); + self.interner.push_expr_location(id, span, self.file); + let typ = self.type_check_variable(expr, id); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } + + fn resolve_variable(&mut self, path: Path) -> HirIdent { + if let Some((method, constraint, assumed)) = self.resolve_trait_generic_path(&path) { + HirIdent { + location: Location::new(path.span, self.file), + id: self.interner.trait_method_id(method), + impl_kind: ImplKind::TraitMethod(method, constraint, assumed), + } + } else { + // If the Path is being used as an Expression, then it is referring to a global from a separate module + // Otherwise, then it is referring to an Identifier + // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; + // If the expression is a singular indent, we search the resolver's current scope as normal. + let (hir_ident, var_scope_index) = self.get_ident_from_path(path); + + if hir_ident.id != DefinitionId::dummy_id() { + match self.interner.definition(hir_ident.id).kind { + DefinitionKind::Function(id) => { + if let Some(current_item) = self.current_item { + self.interner.add_function_dependency(current_item, id); + } + } + DefinitionKind::Global(global_id) => { + if let Some(current_item) = self.current_item { + self.interner.add_global_dependency(current_item, global_id); + } + } + DefinitionKind::GenericType(_) => { + // Initialize numeric generics to a polymorphic integer type in case + // they're used in expressions. We must do this here since type_check_variable + // does not check definition kinds and otherwise expects parameters to + // already be typed. + if self.interner.definition_type(hir_ident.id) == Type::Error { + let typ = Type::polymorphic_integer_or_field(&mut self.interner); + self.interner.push_definition_type(hir_ident.id, typ); + } + } + DefinitionKind::Local(_) => { + // only local variables can be captured by closures. + self.resolve_local_variable(hir_ident.clone(), var_scope_index); + } + } + } + + hir_ident + } + } + + pub(super) fn type_check_variable(&mut self, ident: HirIdent, expr_id: ExprId) -> Type { + let mut bindings = TypeBindings::new(); + + // Add type bindings from any constraints that were used. + // We need to do this first since otherwise instantiating the type below + // will replace each trait generic with a fresh type variable, rather than + // the type used in the trait constraint (if it exists). See #4088. + if let ImplKind::TraitMethod(_, constraint, _) = &ident.impl_kind { + let the_trait = self.interner.get_trait(constraint.trait_id); + assert_eq!(the_trait.generics.len(), constraint.trait_generics.len()); + + for (param, arg) in the_trait.generics.iter().zip(&constraint.trait_generics) { + // Avoid binding t = t + if !arg.occurs(param.id()) { + bindings.insert(param.id(), (param.clone(), arg.clone())); + } + } + } + + // An identifiers type may be forall-quantified in the case of generic functions. + // E.g. `fn foo(t: T, field: Field) -> T` has type `forall T. fn(T, Field) -> T`. + // We must instantiate identifiers at every call site to replace this T with a new type + // variable to handle generic functions. + let t = self.interner.id_type_substitute_trait_as_type(ident.id); + + // This instantiates a trait's generics as well which need to be set + // when the constraint below is later solved for when the function is + // finished. How to link the two? + let (typ, bindings) = t.instantiate_with_bindings(bindings, &self.interner); + + // Push any trait constraints required by this definition to the context + // to be checked later when the type of this variable is further constrained. + if let Some(definition) = self.interner.try_definition(ident.id) { + if let DefinitionKind::Function(function) = definition.kind { + let function = self.interner.function_meta(&function); + + for mut constraint in function.trait_constraints.clone() { + constraint.apply_bindings(&bindings); + self.trait_constraints.push((constraint, expr_id)); + } + } + } + + if let ImplKind::TraitMethod(_, mut constraint, assumed) = ident.impl_kind { + constraint.apply_bindings(&bindings); + if assumed { + let trait_impl = TraitImplKind::Assumed { + object_type: constraint.typ, + trait_generics: constraint.trait_generics, + }; + self.interner.select_impl_for_expression(expr_id, trait_impl); + } else { + // Currently only one impl can be selected per expr_id, so this + // constraint needs to be pushed after any other constraints so + // that monomorphization can resolve this trait method to the correct impl. + self.trait_constraints.push((constraint, expr_id)); + } + } + + self.interner.store_instantiation_bindings(expr_id, bindings); + typ + } + + fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) { + let location = Location::new(path.span(), self.file); + + let error = match path.as_ident().map(|ident| self.use_variable(ident)) { + Some(Ok(found)) => return found, + // Try to look it up as a global, but still issue the first error if we fail + Some(Err(error)) => match self.lookup_global(path) { + Ok(id) => return (HirIdent::non_trait_method(id, location), 0), + Err(_) => error, + }, + None => match self.lookup_global(path) { + Ok(id) => return (HirIdent::non_trait_method(id, location), 0), + Err(error) => error, + }, + }; + self.push_err(error); + let id = DefinitionId::dummy_id(); + (HirIdent::non_trait_method(id, location), 0) + } +} diff --git a/compiler/noirc_frontend/src/elaborator/scope.rs b/compiler/noirc_frontend/src/elaborator/scope.rs new file mode 100644 index 00000000000..e8baba6bd89 --- /dev/null +++ b/compiler/noirc_frontend/src/elaborator/scope.rs @@ -0,0 +1,100 @@ +use rustc_hash::FxHashMap as HashMap; + +use crate::hir::comptime::Value; +use crate::{ + hir::{ + def_map::{ModuleDefId, TryFromModuleDefId}, + resolution::errors::ResolverError, + }, + hir_def::{ + expr::{HirCapturedVar, HirIdent}, + traits::Trait, + }, + macros_api::{Path, StructId}, + node_interner::{DefinitionId, TraitId, TypeAliasId}, + Shared, StructType, +}; + +use super::Elaborator; + +impl Elaborator { + pub(super) fn lookup(&mut self, path: Path) -> Result { + let span = path.span(); + let id = self.resolve_path(path)?; + T::try_from(id).ok_or_else(|| ResolverError::Expected { + expected: T::description(), + got: id.as_str().to_owned(), + span, + }) + } + + pub(super) fn resolve_path(&mut self, path: Path) -> Result { + let path_resolution = self.path_resolver.resolve(&self.def_maps, path)?; + + if let Some(error) = path_resolution.error { + self.push_err(error); + } + + Ok(path_resolution.module_def_id) + } + + pub(super) fn get_struct(&self, type_id: StructId) -> Shared { + self.interner.get_struct(type_id) + } + + pub(super) fn get_trait_mut(&mut self, trait_id: TraitId) -> &mut Trait { + self.interner.get_trait_mut(trait_id) + } + + pub(super) fn resolve_local_variable(&mut self, hir_ident: HirIdent, var_scope_index: usize) { + let mut transitive_capture_index: Option = None; + + for lambda_index in 0..self.lambda_stack.len() { + if self.lambda_stack[lambda_index].scope_index > var_scope_index { + // Beware: the same variable may be captured multiple times, so we check + // for its presence before adding the capture below. + let position = self.lambda_stack[lambda_index] + .captures + .iter() + .position(|capture| capture.ident.id == hir_ident.id); + + if position.is_none() { + self.lambda_stack[lambda_index].captures.push(HirCapturedVar { + ident: hir_ident.clone(), + transitive_capture_index, + }); + } + + if lambda_index + 1 < self.lambda_stack.len() { + // There is more than one closure between the current scope and + // the scope of the variable, so this is a propagated capture. + // We need to track the transitive capture index as we go up in + // the closure stack. + transitive_capture_index = Some(position.unwrap_or( + // If this was a fresh capture, we added it to the end of + // the captures vector: + self.lambda_stack[lambda_index].captures.len() - 1, + )); + } + } + } + } + + pub(super) fn lookup_global(&mut self, path: Path) -> Result { + let span = path.span(); + let id = self.resolve_path(path)?; + + if let Some(function) = TryFromModuleDefId::try_from(id) { + return Ok(self.interner.function_definition_id(function)); + } + + if let Some(global) = TryFromModuleDefId::try_from(id) { + let global = self.interner.get_global(global); + return Ok(global.definition_id); + } + + let expected = "global variable".into(); + let got = "local variable".into(); + Err(ResolverError::Expected { span, expected, got }) + } +} diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs new file mode 100644 index 00000000000..6ad6d66b2d4 --- /dev/null +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -0,0 +1,409 @@ +use noirc_errors::{Location, Span}; + +use crate::{ + ast::{AssignStatement, ConstrainStatement, LValue}, + hir::{ + resolution::errors::ResolverError, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::HirIdent, + stmt::{ + HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement, + }, + }, + macros_api::{ + ForLoopStatement, ForRange, HirStatement, LetStatement, Statement, StatementKind, + }, + node_interner::{DefinitionId, DefinitionKind, StmtId}, + Type, +}; + +use super::Elaborator; + +impl Elaborator { + fn elaborate_statement_value(&mut self, statement: Statement) -> (HirStatement, Type) { + match statement.kind { + StatementKind::Let(let_stmt) => self.elaborate_let(let_stmt), + StatementKind::Constrain(constrain) => self.elaborate_constrain(constrain), + StatementKind::Assign(assign) => self.elaborate_assign(assign), + StatementKind::For(for_stmt) => self.elaborate_for(for_stmt), + StatementKind::Break => self.elaborate_jump(true, statement.span), + StatementKind::Continue => self.elaborate_jump(false, statement.span), + StatementKind::Comptime(statement) => self.elaborate_comptime(*statement), + StatementKind::Expression(expr) => { + let (expr, typ) = self.elaborate_expression(expr); + (HirStatement::Expression(expr), typ) + } + StatementKind::Semi(expr) => { + let (expr, _typ) = self.elaborate_expression(expr); + (HirStatement::Semi(expr), Type::Unit) + } + StatementKind::Error => (HirStatement::Error, Type::Error), + } + } + + pub(super) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) { + let span = statement.span; + let (hir_statement, typ) = self.elaborate_statement_value(statement); + let id = self.interner.push_stmt(hir_statement); + self.interner.push_stmt_location(id, span, self.file); + (id, typ) + } + + pub(super) fn elaborate_let(&mut self, let_stmt: LetStatement) -> (HirStatement, Type) { + let expr_span = let_stmt.expression.span; + let (expression, expr_type) = self.elaborate_expression(let_stmt.expression); + let definition = DefinitionKind::Local(Some(expression)); + let annotated_type = self.resolve_type(let_stmt.r#type); + + // First check if the LHS is unspecified + // If so, then we give it the same type as the expression + let r#type = if annotated_type != Type::Error { + // Now check if LHS is the same type as the RHS + // Importantly, we do not coerce any types implicitly + self.unify_with_coercions(&expr_type, &annotated_type, expression, || { + TypeCheckError::TypeMismatch { + expected_typ: annotated_type.to_string(), + expr_typ: expr_type.to_string(), + expr_span, + } + }); + if annotated_type.is_unsigned() { + self.lint_overflowing_uint(&expression, &annotated_type); + } + annotated_type + } else { + expr_type + }; + + let let_ = HirLetStatement { + pattern: self.elaborate_pattern(let_stmt.pattern, r#type.clone(), definition), + r#type, + expression, + attributes: let_stmt.attributes, + comptime: let_stmt.comptime, + }; + (HirStatement::Let(let_), Type::Unit) + } + + pub(super) fn elaborate_constrain(&mut self, stmt: ConstrainStatement) -> (HirStatement, Type) { + let expr_span = stmt.0.span; + let (expr_id, expr_type) = self.elaborate_expression(stmt.0); + + // Must type check the assertion message expression so that we instantiate bindings + let msg = stmt.1.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0); + + self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch { + expr_typ: expr_type.to_string(), + expected_typ: Type::Bool.to_string(), + expr_span, + }); + + (HirStatement::Constrain(HirConstrainStatement(expr_id, self.file, msg)), Type::Unit) + } + + pub(super) fn elaborate_assign(&mut self, assign: AssignStatement) -> (HirStatement, Type) { + let span = assign.expression.span; + let (expression, expr_type) = self.elaborate_expression(assign.expression); + let (lvalue, lvalue_type, mutable) = self.elaborate_lvalue(assign.lvalue, span); + + if !mutable { + let (name, span) = self.get_lvalue_name_and_span(&lvalue); + self.push_err(TypeCheckError::VariableMustBeMutable { name, span }); + } + + self.unify_with_coercions(&expr_type, &lvalue_type, expression, || { + TypeCheckError::TypeMismatchWithSource { + actual: expr_type.clone(), + expected: lvalue_type.clone(), + span, + source: Source::Assignment, + } + }); + + let stmt = HirAssignStatement { lvalue, expression }; + (HirStatement::Assign(stmt), Type::Unit) + } + + pub(super) fn elaborate_for(&mut self, for_loop: ForLoopStatement) -> (HirStatement, Type) { + let (start, end) = match for_loop.range { + ForRange::Range(start, end) => (start, end), + ForRange::Array(_) => { + let for_stmt = + for_loop.range.into_for(for_loop.identifier, for_loop.block, for_loop.span); + + return self.elaborate_statement_value(for_stmt); + } + }; + + let start_span = start.span; + let end_span = end.span; + + let (start_range, start_range_type) = self.elaborate_expression(start); + let (end_range, end_range_type) = self.elaborate_expression(end); + let (identifier, block) = (for_loop.identifier, for_loop.block); + + self.nested_loops += 1; + self.push_scope(); + + // TODO: For loop variables are currently mutable by default since we haven't + // yet implemented syntax for them to be optionally mutable. + let kind = DefinitionKind::Local(None); + let identifier = self.add_variable_decl(identifier, false, true, kind); + + // Check that start range and end range have the same types + let range_span = start_span.merge(end_span); + self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch { + expected_typ: start_range_type.to_string(), + expr_typ: end_range_type.to_string(), + expr_span: range_span, + }); + + let expected_type = self.polymorphic_integer(); + + self.unify(&start_range_type, &expected_type, || TypeCheckError::TypeCannotBeUsed { + typ: start_range_type.clone(), + place: "for loop", + span: range_span, + }); + + self.interner.push_definition_type(identifier.id, start_range_type); + + let (block, _block_type) = self.elaborate_expression(block); + + self.pop_scope(); + self.nested_loops -= 1; + + let statement = + HirStatement::For(HirForStatement { start_range, end_range, block, identifier }); + + (statement, Type::Unit) + } + + fn elaborate_jump(&mut self, is_break: bool, span: noirc_errors::Span) -> (HirStatement, Type) { + if !self.in_unconstrained_fn { + self.push_err(ResolverError::JumpInConstrainedFn { is_break, span }); + } + if self.nested_loops == 0 { + self.push_err(ResolverError::JumpOutsideLoop { is_break, span }); + } + + let expr = if is_break { HirStatement::Break } else { HirStatement::Continue }; + (expr, self.interner.next_type_variable()) + } + + fn get_lvalue_name_and_span(&self, lvalue: &HirLValue) -> (String, Span) { + match lvalue { + HirLValue::Ident(name, _) => { + let span = name.location.span; + + if let Some(definition) = self.interner.try_definition(name.id) { + (definition.name.clone(), span) + } else { + ("(undeclared variable)".into(), span) + } + } + HirLValue::MemberAccess { object, .. } => self.get_lvalue_name_and_span(object), + HirLValue::Index { array, .. } => self.get_lvalue_name_and_span(array), + HirLValue::Dereference { lvalue, .. } => self.get_lvalue_name_and_span(lvalue), + } + } + + fn elaborate_lvalue(&mut self, lvalue: LValue, assign_span: Span) -> (HirLValue, Type, bool) { + match lvalue { + LValue::Ident(ident) => { + let mut mutable = true; + let (ident, scope_index) = self.find_variable_or_default(&ident); + self.resolve_local_variable(ident.clone(), scope_index); + + let typ = if ident.id == DefinitionId::dummy_id() { + Type::Error + } else { + if let Some(definition) = self.interner.try_definition(ident.id) { + mutable = definition.mutable; + } + + let typ = self.interner.definition_type(ident.id).instantiate(&self.interner).0; + typ.follow_bindings() + }; + + (HirLValue::Ident(ident.clone(), typ.clone()), typ, mutable) + } + LValue::MemberAccess { object, field_name, span } => { + let (object, lhs_type, mut mutable) = self.elaborate_lvalue(*object, assign_span); + let mut object = Box::new(object); + let field_name = field_name.clone(); + + let object_ref = &mut object; + let mutable_ref = &mut mutable; + let location = Location::new(span, self.file); + + let dereference_lhs = move |_: &mut Self, _, element_type| { + // We must create a temporary value first to move out of object_ref before + // we eventually reassign to it. + let id = DefinitionId::dummy_id(); + let ident = HirIdent::non_trait_method(id, location); + let tmp_value = HirLValue::Ident(ident, Type::Error); + + let lvalue = std::mem::replace(object_ref, Box::new(tmp_value)); + *object_ref = + Box::new(HirLValue::Dereference { lvalue, element_type, location }); + *mutable_ref = true; + }; + + let name = &field_name.0.contents; + let (object_type, field_index) = self + .check_field_access(&lhs_type, name, field_name.span(), Some(dereference_lhs)) + .unwrap_or((Type::Error, 0)); + + let field_index = Some(field_index); + let typ = object_type.clone(); + let lvalue = + HirLValue::MemberAccess { object, field_name, field_index, typ, location }; + (lvalue, object_type, mutable) + } + LValue::Index { array, index, span } => { + let expr_span = index.span; + let (index, index_type) = self.elaborate_expression(index); + let location = Location::new(span, self.file); + + let expected = self.polymorphic_integer_or_field(); + self.unify(&index_type, &expected, || TypeCheckError::TypeMismatch { + expected_typ: "an integer".to_owned(), + expr_typ: index_type.to_string(), + expr_span, + }); + + let (mut lvalue, mut lvalue_type, mut mutable) = + self.elaborate_lvalue(*array, assign_span); + + // Before we check that the lvalue is an array, try to dereference it as many times + // as needed to unwrap any &mut wrappers. + while let Type::MutableReference(element) = lvalue_type.follow_bindings() { + let element_type = element.as_ref().clone(); + lvalue = + HirLValue::Dereference { lvalue: Box::new(lvalue), element_type, location }; + lvalue_type = *element; + // We know this value to be mutable now since we found an `&mut` + mutable = true; + } + + let typ = match lvalue_type.follow_bindings() { + Type::Array(_, elem_type) => *elem_type, + Type::Slice(elem_type) => *elem_type, + Type::Error => Type::Error, + Type::String(_) => { + let (_lvalue_name, lvalue_span) = self.get_lvalue_name_and_span(&lvalue); + self.push_err(TypeCheckError::StringIndexAssign { span: lvalue_span }); + Type::Error + } + other => { + // TODO: Need a better span here + self.push_err(TypeCheckError::TypeMismatch { + expected_typ: "array".to_string(), + expr_typ: other.to_string(), + expr_span: assign_span, + }); + Type::Error + } + }; + + let array = Box::new(lvalue); + let array_type = typ.clone(); + (HirLValue::Index { array, index, typ, location }, array_type, mutable) + } + LValue::Dereference(lvalue, span) => { + let (lvalue, reference_type, _) = self.elaborate_lvalue(*lvalue, assign_span); + let lvalue = Box::new(lvalue); + let location = Location::new(span, self.file); + + let element_type = Type::type_variable(self.interner.next_type_variable_id()); + let expected_type = Type::MutableReference(Box::new(element_type.clone())); + + self.unify(&reference_type, &expected_type, || TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: reference_type.to_string(), + expr_span: assign_span, + }); + + // Dereferences are always mutable since we already type checked against a &mut T + let typ = element_type.clone(); + let lvalue = HirLValue::Dereference { lvalue, element_type, location }; + (lvalue, typ, true) + } + } + } + + /// Type checks a field access, adding dereference operators as necessary + pub(super) fn check_field_access( + &mut self, + lhs_type: &Type, + field_name: &str, + span: Span, + dereference_lhs: Option, + ) -> Option<(Type, usize)> { + let lhs_type = lhs_type.follow_bindings(); + + match &lhs_type { + Type::Struct(s, args) => { + let s = s.borrow(); + if let Some((field, index)) = s.get_field(field_name, args) { + return Some((field, index)); + } + } + Type::Tuple(elements) => { + if let Ok(index) = field_name.parse::() { + let length = elements.len(); + if index < length { + return Some((elements[index].clone(), index)); + } else { + self.push_err(TypeCheckError::TupleIndexOutOfBounds { + index, + lhs_type, + length, + span, + }); + return None; + } + } + } + // If the lhs is a mutable reference we automatically transform + // lhs.field into (*lhs).field + Type::MutableReference(element) => { + if let Some(mut dereference_lhs) = dereference_lhs { + dereference_lhs(self, lhs_type.clone(), element.as_ref().clone()); + return self.check_field_access( + element, + field_name, + span, + Some(dereference_lhs), + ); + } else { + let (element, index) = + self.check_field_access(element, field_name, span, dereference_lhs)?; + return Some((Type::MutableReference(Box::new(element)), index)); + } + } + _ => (), + } + + // If we get here the type has no field named 'access.rhs'. + // Now we specialize the error message based on whether we know the object type in question yet. + if let Type::TypeVariable(..) = &lhs_type { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + } else if lhs_type != Type::Error { + self.push_err(TypeCheckError::AccessUnknownMember { + lhs_type, + field_name: field_name.to_string(), + span, + }); + } + + None + } + + pub(super) fn elaborate_comptime(&self, _statement: Statement) -> (HirStatement, Type) { + todo!("Comptime scanning") + } +} diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs new file mode 100644 index 00000000000..2cd699b919d --- /dev/null +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -0,0 +1,1399 @@ +use std::rc::Rc; + +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; + +use crate::{ + ast::{BinaryOpKind, IntegerBitSize, UnresolvedTraitConstraint, UnresolvedTypeExpression}, + hir::{ + def_map::ModuleDefId, + resolution::{ + errors::ResolverError, + import::PathResolution, + resolver::{verify_mutable_reference, SELF_TYPE_NAME}, + }, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::{ + HirBinaryOp, HirCallExpression, HirIdent, HirMemberAccess, HirMethodReference, + HirPrefixExpression, + }, + traits::{Trait, TraitConstraint}, + }, + macros_api::{ + HirExpression, HirLiteral, Path, PathKind, SecondaryAttribute, Signedness, UnaryOp, + UnresolvedType, UnresolvedTypeData, + }, + node_interner::{DefinitionKind, ExprId, GlobalId, TraitImplKind, TraitMethodId}, + Generics, Shared, StructType, Type, TypeAlias, TypeBinding, TypeVariable, TypeVariableKind, +}; + +use super::Elaborator; + +impl Elaborator { + /// Translates an UnresolvedType to a Type + pub(super) fn resolve_type(&mut self, typ: UnresolvedType) -> Type { + let span = typ.span; + let resolved_type = self.resolve_type_inner(typ, &mut vec![]); + if resolved_type.is_nested_slice() { + self.push_err(ResolverError::NestedSlices { span: span.unwrap() }); + } + + resolved_type + } + + /// Translates an UnresolvedType into a Type and appends any + /// freshly created TypeVariables created to new_variables. + fn resolve_type_inner(&mut self, typ: UnresolvedType, new_variables: &mut Generics) -> Type { + use crate::ast::UnresolvedTypeData::*; + + let resolved_type = match typ.typ { + FieldElement => Type::FieldElement, + Array(size, elem) => { + let elem = Box::new(self.resolve_type_inner(*elem, new_variables)); + let size = self.resolve_array_size(Some(size), new_variables); + Type::Array(Box::new(size), elem) + } + Slice(elem) => { + let elem = Box::new(self.resolve_type_inner(*elem, new_variables)); + Type::Slice(elem) + } + Expression(expr) => self.convert_expression_type(expr), + Integer(sign, bits) => Type::Integer(sign, bits), + Bool => Type::Bool, + String(size) => { + let resolved_size = self.resolve_array_size(size, new_variables); + Type::String(Box::new(resolved_size)) + } + FormatString(size, fields) => { + let resolved_size = self.convert_expression_type(size); + let fields = self.resolve_type_inner(*fields, new_variables); + Type::FmtString(Box::new(resolved_size), Box::new(fields)) + } + Code => Type::Code, + Unit => Type::Unit, + Unspecified => Type::Error, + Error => Type::Error, + Named(path, args, _) => self.resolve_named_type(path, args, new_variables), + TraitAsType(path, args) => self.resolve_trait_as_type(path, args, new_variables), + + Tuple(fields) => { + Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables))) + } + Function(args, ret, env) => { + let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + let ret = Box::new(self.resolve_type_inner(*ret, new_variables)); + + // expect() here is valid, because the only places we don't have a span are omitted types + // e.g. a function without return type implicitly has a spanless UnresolvedType::Unit return type + // To get an invalid env type, the user must explicitly specify the type, which will have a span + let env_span = + env.span.expect("Unexpected missing span for closure environment type"); + + let env = Box::new(self.resolve_type_inner(*env, new_variables)); + + match *env { + Type::Unit | Type::Tuple(_) | Type::NamedGeneric(_, _) => { + Type::Function(args, ret, env) + } + _ => { + self.push_err(ResolverError::InvalidClosureEnvironment { + typ: *env, + span: env_span, + }); + Type::Error + } + } + } + MutableReference(element) => { + Type::MutableReference(Box::new(self.resolve_type_inner(*element, new_variables))) + } + Parenthesized(typ) => self.resolve_type_inner(*typ, new_variables), + }; + + if let Type::Struct(_, _) = resolved_type { + if let Some(unresolved_span) = typ.span { + // Record the location of the type reference + self.interner.push_type_ref_location( + resolved_type.clone(), + Location::new(unresolved_span, self.file), + ); + } + } + resolved_type + } + + fn find_generic(&self, target_name: &str) -> Option<&(Rc, TypeVariable, Span)> { + self.generics.iter().find(|(name, _, _)| name.as_ref() == target_name) + } + + fn resolve_named_type( + &mut self, + path: Path, + args: Vec, + new_variables: &mut Generics, + ) -> Type { + if args.is_empty() { + if let Some(typ) = self.lookup_generic_or_global_type(&path) { + return typ; + } + } + + // Check if the path is a type variable first. We currently disallow generics on type + // variables since we do not support higher-kinded types. + if path.segments.len() == 1 { + let name = &path.last_segment().0.contents; + + if name == SELF_TYPE_NAME { + if let Some(self_type) = self.self_type.clone() { + if !args.is_empty() { + self.push_err(ResolverError::GenericsOnSelfType { span: path.span() }); + } + return self_type; + } + } + } + + let span = path.span(); + let mut args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + + if let Some(type_alias) = self.lookup_type_alias(path.clone()) { + let type_alias = type_alias.borrow(); + let expected_generic_count = type_alias.generics.len(); + let type_alias_string = type_alias.to_string(); + let id = type_alias.id; + + self.verify_generics_count(expected_generic_count, &mut args, span, || { + type_alias_string + }); + + if let Some(item) = self.current_item { + self.interner.add_type_alias_dependency(item, id); + } + + // Collecting Type Alias references [Location]s to be used by LSP in order + // to resolve the definition of the type alias + self.interner.add_type_alias_ref(id, Location::new(span, self.file)); + + // Because there is no ordering to when type aliases (and other globals) are resolved, + // it is possible for one to refer to an Error type and issue no error if it is set + // equal to another type alias. Fixing this fully requires an analysis to create a DFG + // of definition ordering, but for now we have an explicit check here so that we at + // least issue an error that the type was not found instead of silently passing. + let alias = self.interner.get_type_alias(id); + return Type::Alias(alias, args); + } + + match self.lookup_struct_or_error(path) { + Some(struct_type) => { + if self.resolving_ids.contains(&struct_type.borrow().id) { + self.push_err(ResolverError::SelfReferentialStruct { + span: struct_type.borrow().name.span(), + }); + + return Type::Error; + } + + let expected_generic_count = struct_type.borrow().generics.len(); + if !self.in_contract + && self + .interner + .struct_attributes(&struct_type.borrow().id) + .iter() + .any(|attr| matches!(attr, SecondaryAttribute::Abi(_))) + { + self.push_err(ResolverError::AbiAttributeOutsideContract { + span: struct_type.borrow().name.span(), + }); + } + self.verify_generics_count(expected_generic_count, &mut args, span, || { + struct_type.borrow().to_string() + }); + + if let Some(current_item) = self.current_item { + let dependency_id = struct_type.borrow().id; + self.interner.add_type_dependency(current_item, dependency_id); + } + + Type::Struct(struct_type, args) + } + None => Type::Error, + } + } + + fn resolve_trait_as_type( + &mut self, + path: Path, + args: Vec, + new_variables: &mut Generics, + ) -> Type { + let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + + if let Some(t) = self.lookup_trait_or_error(path) { + Type::TraitAsType(t.id, Rc::new(t.name.to_string()), args) + } else { + Type::Error + } + } + + fn verify_generics_count( + &mut self, + expected_count: usize, + args: &mut Vec, + span: Span, + type_name: impl FnOnce() -> String, + ) { + if args.len() != expected_count { + self.push_err(ResolverError::IncorrectGenericCount { + span, + item_name: type_name(), + actual: args.len(), + expected: expected_count, + }); + + // Fix the generic count so we can continue typechecking + args.resize_with(expected_count, || Type::Error); + } + } + + fn lookup_generic_or_global_type(&mut self, path: &Path) -> Option { + if path.segments.len() == 1 { + let name = &path.last_segment().0.contents; + if let Some((name, var, _)) = self.find_generic(name) { + return Some(Type::NamedGeneric(var.clone(), name.clone())); + } + } + + // If we cannot find a local generic of the same name, try to look up a global + match self.path_resolver.resolve(&self.def_maps, path.clone()) { + Ok(PathResolution { module_def_id: ModuleDefId::GlobalId(id), error }) => { + if let Some(current_item) = self.current_item { + self.interner.add_global_dependency(current_item, id); + } + + if let Some(error) = error { + self.push_err(error); + } + Some(Type::Constant(self.eval_global_as_array_length(id, path))) + } + _ => None, + } + } + + fn resolve_array_size( + &mut self, + length: Option, + new_variables: &mut Generics, + ) -> Type { + match length { + None => { + let id = self.interner.next_type_variable_id(); + let typevar = TypeVariable::unbound(id); + new_variables.push(typevar.clone()); + + // 'Named'Generic is a bit of a misnomer here, we want a type variable that + // wont be bound over but this one has no name since we do not currently + // require users to explicitly be generic over array lengths. + Type::NamedGeneric(typevar, Rc::new("".into())) + } + Some(length) => self.convert_expression_type(length), + } + } + + pub(super) fn convert_expression_type(&mut self, length: UnresolvedTypeExpression) -> Type { + match length { + UnresolvedTypeExpression::Variable(path) => { + self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { + self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); + Type::Constant(0) + }) + } + UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); + let lhs = self.convert_expression_type(*lhs); + let rhs = self.convert_expression_type(*rhs); + + match (lhs, rhs) { + (Type::Constant(lhs), Type::Constant(rhs)) => { + Type::Constant(op.function()(lhs, rhs)) + } + (lhs, _) => { + let span = + if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; + self.push_err(ResolverError::InvalidArrayLengthExpr { span }); + Type::Constant(0) + } + } + } + } + } + + /// Lookup a given struct type by name. + fn lookup_struct_or_error(&mut self, path: Path) -> Option> { + match self.lookup(path) { + Ok(struct_id) => Some(self.get_struct(struct_id)), + Err(error) => { + self.push_err(error); + None + } + } + } + + /// Lookup a given trait by name/path. + fn lookup_trait_or_error(&mut self, path: Path) -> Option<&mut Trait> { + match self.lookup(path) { + Ok(trait_id) => Some(self.get_trait_mut(trait_id)), + Err(error) => { + self.push_err(error); + None + } + } + } + + /// Looks up a given type by name. + /// This will also instantiate any struct types found. + pub(super) fn lookup_type_or_error(&mut self, path: Path) -> Option { + let ident = path.as_ident(); + if ident.map_or(false, |i| i == SELF_TYPE_NAME) { + if let Some(typ) = &self.self_type { + return Some(typ.clone()); + } + } + + match self.lookup(path) { + Ok(struct_id) => { + let struct_type = self.get_struct(struct_id); + let generics = struct_type.borrow().instantiate(&mut self.interner); + Some(Type::Struct(struct_type, generics)) + } + Err(error) => { + self.push_err(error); + None + } + } + } + + fn lookup_type_alias(&mut self, path: Path) -> Option> { + self.lookup(path).ok().map(|id| self.interner.get_type_alias(id)) + } + + // this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type) + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_static_method_by_self( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + let trait_id = self.trait_id?; + + if path.kind == PathKind::Plain && path.segments.len() == 2 { + let name = &path.segments[0].0.contents; + let method = &path.segments[1]; + + if name == SELF_TYPE_NAME { + let the_trait = self.interner.get_trait(trait_id); + let method = the_trait.find_method(method.0.contents.as_str())?; + + let constraint = TraitConstraint { + typ: self.self_type.clone()?, + trait_generics: Type::from_generics(&the_trait.generics), + trait_id, + }; + return Some((method, constraint, false)); + } + } + None + } + + // this resolves TraitName::some_static_method + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_static_method( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + if path.kind == PathKind::Plain && path.segments.len() == 2 { + let method = &path.segments[1]; + + let mut trait_path = path.clone(); + trait_path.pop(); + let trait_id = self.lookup(trait_path).ok()?; + let the_trait = self.interner.get_trait(trait_id); + + let method = the_trait.find_method(method.0.contents.as_str())?; + let constraint = TraitConstraint { + typ: Type::TypeVariable( + the_trait.self_type_typevar.clone(), + TypeVariableKind::Normal, + ), + trait_generics: Type::from_generics(&the_trait.generics), + trait_id, + }; + return Some((method, constraint, false)); + } + None + } + + // This resolves a static trait method T::trait_method by iterating over the where clause + // + // Returns the trait method, trait constraint, and whether the impl is assumed from a where + // clause. This is always true since this helper searches where clauses for a generic constraint. + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_method_by_named_generic( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + if path.segments.len() != 2 { + return None; + } + + for UnresolvedTraitConstraint { typ, trait_bound } in self.trait_bounds.clone() { + if let UnresolvedTypeData::Named(constraint_path, _, _) = &typ.typ { + // if `path` is `T::method_name`, we're looking for constraint of the form `T: SomeTrait` + if constraint_path.segments.len() == 1 + && path.segments[0] != constraint_path.last_segment() + { + continue; + } + + if let Ok(ModuleDefId::TraitId(trait_id)) = + self.resolve_path(trait_bound.trait_path.clone()) + { + let the_trait = self.interner.get_trait(trait_id); + if let Some(method) = + the_trait.find_method(path.segments.last().unwrap().0.contents.as_str()) + { + let constraint = TraitConstraint { + trait_id, + typ: self.resolve_type(typ.clone()), + trait_generics: vecmap(trait_bound.trait_generics, |typ| { + self.resolve_type(typ) + }), + }; + return Some((method, constraint, true)); + } + } + } + } + None + } + + // Try to resolve the given trait method path. + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + pub(super) fn resolve_trait_generic_path( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + self.resolve_trait_static_method_by_self(path) + .or_else(|| self.resolve_trait_static_method(path)) + .or_else(|| self.resolve_trait_method_by_named_generic(path)) + } + + fn eval_global_as_array_length(&mut self, global: GlobalId, path: &Path) -> u64 { + let Some(stmt) = self.interner.get_global_let_statement(global) else { + let path = path.clone(); + self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); + return 0; + }; + + let length = stmt.expression; + let span = self.interner.expr_span(&length); + let result = self.try_eval_array_length_id(length, span); + + match result.map(|length| length.try_into()) { + Ok(Ok(length_value)) => return length_value, + Ok(Err(_cast_err)) => self.push_err(ResolverError::IntegerTooLarge { span }), + Err(Some(error)) => self.push_err(error), + Err(None) => (), + } + 0 + } + + fn try_eval_array_length_id( + &self, + rhs: ExprId, + span: Span, + ) -> Result> { + // Arbitrary amount of recursive calls to try before giving up + let fuel = 100; + self.try_eval_array_length_id_with_fuel(rhs, span, fuel) + } + + fn try_eval_array_length_id_with_fuel( + &self, + rhs: ExprId, + span: Span, + fuel: u32, + ) -> Result> { + if fuel == 0 { + // If we reach here, it is likely from evaluating cyclic globals. We expect an error to + // be issued for them after name resolution so issue no error now. + return Err(None); + } + + match self.interner.expression(&rhs) { + HirExpression::Literal(HirLiteral::Integer(int, false)) => { + int.try_into_u128().ok_or(Some(ResolverError::IntegerTooLarge { span })) + } + HirExpression::Ident(ident) => { + let definition = self.interner.definition(ident.id); + match definition.kind { + DefinitionKind::Global(global_id) => { + let let_statement = self.interner.get_global_let_statement(global_id); + if let Some(let_statement) = let_statement { + let expression = let_statement.expression; + self.try_eval_array_length_id_with_fuel(expression, span, fuel - 1) + } else { + Err(Some(ResolverError::InvalidArrayLengthExpr { span })) + } + } + _ => Err(Some(ResolverError::InvalidArrayLengthExpr { span })), + } + } + HirExpression::Infix(infix) => { + let lhs = self.try_eval_array_length_id_with_fuel(infix.lhs, span, fuel - 1)?; + let rhs = self.try_eval_array_length_id_with_fuel(infix.rhs, span, fuel - 1)?; + + match infix.operator.kind { + BinaryOpKind::Add => Ok(lhs + rhs), + BinaryOpKind::Subtract => Ok(lhs - rhs), + BinaryOpKind::Multiply => Ok(lhs * rhs), + BinaryOpKind::Divide => Ok(lhs / rhs), + BinaryOpKind::Equal => Ok((lhs == rhs) as u128), + BinaryOpKind::NotEqual => Ok((lhs != rhs) as u128), + BinaryOpKind::Less => Ok((lhs < rhs) as u128), + BinaryOpKind::LessEqual => Ok((lhs <= rhs) as u128), + BinaryOpKind::Greater => Ok((lhs > rhs) as u128), + BinaryOpKind::GreaterEqual => Ok((lhs >= rhs) as u128), + BinaryOpKind::And => Ok(lhs & rhs), + BinaryOpKind::Or => Ok(lhs | rhs), + BinaryOpKind::Xor => Ok(lhs ^ rhs), + BinaryOpKind::ShiftRight => Ok(lhs >> rhs), + BinaryOpKind::ShiftLeft => Ok(lhs << rhs), + BinaryOpKind::Modulo => Ok(lhs % rhs), + } + } + _other => Err(Some(ResolverError::InvalidArrayLengthExpr { span })), + } + } + + /// Check if an assignment is overflowing with respect to `annotated_type` + /// in a declaration statement where `annotated_type` is an unsigned integer + pub(super) fn lint_overflowing_uint(&mut self, rhs_expr: &ExprId, annotated_type: &Type) { + let expr = self.interner.expression(rhs_expr); + let span = self.interner.expr_span(rhs_expr); + match expr { + HirExpression::Literal(HirLiteral::Integer(value, false)) => { + let v = value.to_u128(); + if let Type::Integer(_, bit_count) = annotated_type { + let bit_count: u32 = (*bit_count).into(); + let max = 1 << bit_count; + if v >= max { + self.push_err(TypeCheckError::OverflowingAssignment { + expr: value, + ty: annotated_type.clone(), + range: format!("0..={}", max - 1), + span, + }); + }; + }; + } + HirExpression::Prefix(expr) => { + self.lint_overflowing_uint(&expr.rhs, annotated_type); + if matches!(expr.operator, UnaryOp::Minus) { + self.push_err(TypeCheckError::InvalidUnaryOp { + kind: "annotated_type".to_string(), + span, + }); + } + } + HirExpression::Infix(expr) => { + self.lint_overflowing_uint(&expr.lhs, annotated_type); + self.lint_overflowing_uint(&expr.rhs, annotated_type); + } + _ => {} + } + } + + pub(super) fn unify( + &mut self, + actual: &Type, + expected: &Type, + make_error: impl FnOnce() -> TypeCheckError, + ) { + let mut errors = Vec::new(); + actual.unify(expected, &mut errors, make_error); + self.errors.extend(errors.into_iter().map(Into::into)); + } + + /// Wrapper of Type::unify_with_coercions using self.errors + pub(super) fn unify_with_coercions( + &mut self, + actual: &Type, + expected: &Type, + expression: ExprId, + make_error: impl FnOnce() -> TypeCheckError, + ) { + let mut errors = Vec::new(); + actual.unify_with_coercions( + expected, + expression, + &mut self.interner, + &mut errors, + make_error, + ); + self.errors.extend(errors.into_iter().map(Into::into)); + } + + /// Return a fresh integer or field type variable and log it + /// in self.type_variables to default it later. + pub(super) fn polymorphic_integer_or_field(&mut self) -> Type { + let typ = Type::polymorphic_integer_or_field(&mut self.interner); + self.type_variables.push(typ.clone()); + typ + } + + /// Return a fresh integer type variable and log it + /// in self.type_variables to default it later. + pub(super) fn polymorphic_integer(&mut self) -> Type { + let typ = Type::polymorphic_integer(&mut self.interner); + self.type_variables.push(typ.clone()); + typ + } + + /// Translates a (possibly Unspecified) UnresolvedType to a Type. + /// Any UnresolvedType::Unspecified encountered are replaced with fresh type variables. + pub(super) fn resolve_inferred_type(&mut self, typ: UnresolvedType) -> Type { + match &typ.typ { + UnresolvedTypeData::Unspecified => self.interner.next_type_variable(), + _ => self.resolve_type_inner(typ, &mut vec![]), + } + } + + pub(super) fn type_check_prefix_operand( + &mut self, + op: &crate::ast::UnaryOp, + rhs_type: &Type, + span: Span, + ) -> Type { + let mut unify = |this: &mut Self, expected| { + this.unify(rhs_type, &expected, || TypeCheckError::TypeMismatch { + expr_typ: rhs_type.to_string(), + expected_typ: expected.to_string(), + expr_span: span, + }); + expected + }; + + match op { + crate::ast::UnaryOp::Minus => { + if rhs_type.is_unsigned() { + self.push_err(TypeCheckError::InvalidUnaryOp { + kind: rhs_type.to_string(), + span, + }); + } + let expected = self.polymorphic_integer_or_field(); + self.unify(rhs_type, &expected, || TypeCheckError::InvalidUnaryOp { + kind: rhs_type.to_string(), + span, + }); + expected + } + crate::ast::UnaryOp::Not => { + let rhs_type = rhs_type.follow_bindings(); + + // `!` can work on booleans or integers + if matches!(rhs_type, Type::Integer(..)) { + return rhs_type; + } + + unify(self, Type::Bool) + } + crate::ast::UnaryOp::MutableReference => { + Type::MutableReference(Box::new(rhs_type.follow_bindings())) + } + crate::ast::UnaryOp::Dereference { implicitly_added: _ } => { + let element_type = self.interner.next_type_variable(); + unify(self, Type::MutableReference(Box::new(element_type.clone()))); + element_type + } + } + } + + /// Insert as many dereference operations as necessary to automatically dereference a method + /// call object to its base value type T. + pub(super) fn insert_auto_dereferences(&mut self, object: ExprId, typ: Type) -> (ExprId, Type) { + if let Type::MutableReference(element) = typ { + let location = self.interner.id_location(object); + + let object = self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: UnaryOp::Dereference { implicitly_added: true }, + rhs: object, + })); + self.interner.push_expr_type(object, element.as_ref().clone()); + self.interner.push_expr_location(object, location.span, location.file); + + // Recursively dereference to allow for converting &mut &mut T to T + self.insert_auto_dereferences(object, *element) + } else { + (object, typ) + } + } + + /// Given a method object: `(*foo).bar` of a method call `(*foo).bar.baz()`, remove the + /// implicitly added dereference operator if one is found. + /// + /// Returns Some(new_expr_id) if a dereference was removed and None otherwise. + fn try_remove_implicit_dereference(&mut self, object: ExprId) -> Option { + match self.interner.expression(&object) { + HirExpression::MemberAccess(mut access) => { + let new_lhs = self.try_remove_implicit_dereference(access.lhs)?; + access.lhs = new_lhs; + access.is_offset = true; + + // `object` will have a different type now, which will be filled in + // later when type checking the method call as a function call. + self.interner.replace_expr(&object, HirExpression::MemberAccess(access)); + Some(object) + } + HirExpression::Prefix(prefix) => match prefix.operator { + // Found a dereference we can remove. Now just replace it with its rhs to remove it. + UnaryOp::Dereference { implicitly_added: true } => Some(prefix.rhs), + _ => None, + }, + _ => None, + } + } + + fn bind_function_type_impl( + &mut self, + fn_params: &[Type], + fn_ret: &Type, + callsite_args: &[(Type, ExprId, Span)], + span: Span, + ) -> Type { + if fn_params.len() != callsite_args.len() { + self.push_err(TypeCheckError::ParameterCountMismatch { + expected: fn_params.len(), + found: callsite_args.len(), + span, + }); + return Type::Error; + } + + for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) { + self.unify(arg, param, || TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + }); + } + + fn_ret.clone() + } + + pub(super) fn bind_function_type( + &mut self, + function: Type, + args: Vec<(Type, ExprId, Span)>, + span: Span, + ) -> Type { + // Could do a single unification for the entire function type, but matching beforehand + // lets us issue a more precise error on the individual argument that fails to type check. + match function { + Type::TypeVariable(binding, TypeVariableKind::Normal) => { + if let TypeBinding::Bound(typ) = &*binding.borrow() { + return self.bind_function_type(typ.clone(), args, span); + } + + let ret = self.interner.next_type_variable(); + let args = vecmap(args, |(arg, _, _)| arg); + let env_type = self.interner.next_type_variable(); + let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type)); + + if let Err(error) = binding.try_bind(expected, span) { + self.push_err(error); + } + ret + } + // The closure env is ignored on purpose: call arguments never place + // constraints on closure environments. + Type::Function(parameters, ret, _env) => { + self.bind_function_type_impl(¶meters, &ret, &args, span) + } + Type::Error => Type::Error, + found => { + self.push_err(TypeCheckError::ExpectedFunction { found, span }); + Type::Error + } + } + } + + pub(super) fn check_cast(&mut self, from: Type, to: &Type, span: Span) -> Type { + match from.follow_bindings() { + Type::Integer(..) + | Type::FieldElement + | Type::TypeVariable(_, TypeVariableKind::IntegerOrField) + | Type::TypeVariable(_, TypeVariableKind::Integer) + | Type::Bool => (), + + Type::TypeVariable(_, _) => { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + return Type::Error; + } + Type::Error => return Type::Error, + from => { + self.push_err(TypeCheckError::InvalidCast { from, span }); + return Type::Error; + } + } + + match to { + Type::Integer(sign, bits) => Type::Integer(*sign, *bits), + Type::FieldElement => Type::FieldElement, + Type::Bool => Type::Bool, + Type::Error => Type::Error, + _ => { + self.push_err(TypeCheckError::UnsupportedCast { span }); + Type::Error + } + } + } + + // Given a binary comparison operator and another type. This method will produce the output type + // and a boolean indicating whether to use the trait impl corresponding to the operator + // or not. A value of false indicates the caller to use a primitive operation for this + // operator, while a true value indicates a user-provided trait impl is required. + fn comparator_operand_type_rules( + &mut self, + lhs_type: &Type, + rhs_type: &Type, + op: &HirBinaryOp, + span: Span, + ) -> Result<(Type, bool), TypeCheckError> { + use Type::*; + + match (lhs_type, rhs_type) { + // Avoid reporting errors multiple times + (Error, _) | (_, Error) => Ok((Bool, false)), + (Alias(alias, args), other) | (other, Alias(alias, args)) => { + let alias = alias.borrow().get_type(args); + self.comparator_operand_type_rules(&alias, other, op, span) + } + + // Matches on TypeVariable must be first to follow any type + // bindings. + (TypeVariable(var, _), other) | (other, TypeVariable(var, _)) => { + if let TypeBinding::Bound(binding) = &*var.borrow() { + return self.comparator_operand_type_rules(other, binding, op, span); + } + + let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span); + Ok((Bool, use_impl)) + } + (Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => { + if sign_x != sign_y { + return Err(TypeCheckError::IntegerSignedness { + sign_x: *sign_x, + sign_y: *sign_y, + span, + }); + } + if bit_width_x != bit_width_y { + return Err(TypeCheckError::IntegerBitWidth { + bit_width_x: *bit_width_x, + bit_width_y: *bit_width_y, + span, + }); + } + Ok((Bool, false)) + } + (FieldElement, FieldElement) => { + if op.kind.is_valid_for_field_type() { + Ok((Bool, false)) + } else { + Err(TypeCheckError::FieldComparison { span }) + } + } + + // <= and friends are technically valid for booleans, just not very useful + (Bool, Bool) => Ok((Bool, false)), + + (lhs, rhs) => { + self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource { + expected: lhs.clone(), + actual: rhs.clone(), + span: op.location.span, + source: Source::Binary, + }); + Ok((Bool, true)) + } + } + } + + /// Handles the TypeVariable case for checking binary operators. + /// Returns true if we should use the impl for the operator instead of the primitive + /// version of it. + fn bind_type_variables_for_infix( + &mut self, + lhs_type: &Type, + op: &HirBinaryOp, + rhs_type: &Type, + span: Span, + ) -> bool { + self.unify(lhs_type, rhs_type, || TypeCheckError::TypeMismatchWithSource { + expected: lhs_type.clone(), + actual: rhs_type.clone(), + source: Source::Binary, + span, + }); + + let use_impl = !lhs_type.is_numeric(); + + // If this operator isn't valid for fields we have to possibly narrow + // TypeVariableKind::IntegerOrField to TypeVariableKind::Integer. + // Doing so also ensures a type error if Field is used. + // The is_numeric check is to allow impls for custom types to bypass this. + if !op.kind.is_valid_for_field_type() && lhs_type.is_numeric() { + let target = Type::polymorphic_integer(&mut self.interner); + + use crate::ast::BinaryOpKind::*; + use TypeCheckError::*; + self.unify(lhs_type, &target, || match op.kind { + Less | LessEqual | Greater | GreaterEqual => FieldComparison { span }, + And | Or | Xor | ShiftRight | ShiftLeft => FieldBitwiseOp { span }, + Modulo => FieldModulo { span }, + other => unreachable!("Operator {other:?} should be valid for Field"), + }); + } + + use_impl + } + + // Given a binary operator and another type. This method will produce the output type + // and a boolean indicating whether to use the trait impl corresponding to the operator + // or not. A value of false indicates the caller to use a primitive operation for this + // operator, while a true value indicates a user-provided trait impl is required. + pub(super) fn infix_operand_type_rules( + &mut self, + lhs_type: &Type, + op: &HirBinaryOp, + rhs_type: &Type, + span: Span, + ) -> Result<(Type, bool), TypeCheckError> { + if op.kind.is_comparator() { + return self.comparator_operand_type_rules(lhs_type, rhs_type, op, span); + } + + use Type::*; + match (lhs_type, rhs_type) { + // An error type on either side will always return an error + (Error, _) | (_, Error) => Ok((Error, false)), + (Alias(alias, args), other) | (other, Alias(alias, args)) => { + let alias = alias.borrow().get_type(args); + self.infix_operand_type_rules(&alias, op, other, span) + } + + // Matches on TypeVariable must be first so that we follow any type + // bindings. + (TypeVariable(int, _), other) | (other, TypeVariable(int, _)) => { + if let TypeBinding::Bound(binding) = &*int.borrow() { + return self.infix_operand_type_rules(binding, op, other, span); + } + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + self.unify( + rhs_type, + &Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + || TypeCheckError::InvalidShiftSize { span }, + ); + let use_impl = if lhs_type.is_numeric() { + let integer_type = Type::polymorphic_integer(&mut self.interner); + self.bind_type_variables_for_infix(lhs_type, op, &integer_type, span) + } else { + true + }; + return Ok((lhs_type.clone(), use_impl)); + } + let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span); + Ok((other.clone(), use_impl)) + } + (Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => { + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + if *sign_y != Signedness::Unsigned || *bit_width_y != IntegerBitSize::Eight { + return Err(TypeCheckError::InvalidShiftSize { span }); + } + return Ok((Integer(*sign_x, *bit_width_x), false)); + } + if sign_x != sign_y { + return Err(TypeCheckError::IntegerSignedness { + sign_x: *sign_x, + sign_y: *sign_y, + span, + }); + } + if bit_width_x != bit_width_y { + return Err(TypeCheckError::IntegerBitWidth { + bit_width_x: *bit_width_x, + bit_width_y: *bit_width_y, + span, + }); + } + Ok((Integer(*sign_x, *bit_width_x), false)) + } + // The result of two Fields is always a witness + (FieldElement, FieldElement) => { + if !op.kind.is_valid_for_field_type() { + if op.kind == BinaryOpKind::Modulo { + return Err(TypeCheckError::FieldModulo { span }); + } else { + return Err(TypeCheckError::FieldBitwiseOp { span }); + } + } + Ok((FieldElement, false)) + } + + (Bool, Bool) => Ok((Bool, false)), + + (lhs, rhs) => { + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + if rhs == &Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight) { + return Ok((lhs.clone(), true)); + } + return Err(TypeCheckError::InvalidShiftSize { span }); + } + self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource { + expected: lhs.clone(), + actual: rhs.clone(), + span: op.location.span, + source: Source::Binary, + }); + Ok((lhs.clone(), true)) + } + } + } + + /// Prerequisite: verify_trait_constraint of the operator's trait constraint. + /// + /// Although by this point the operator is expected to already have a trait impl, + /// we still need to match the operator's type against the method's instantiated type + /// to ensure the instantiation bindings are correct and the monomorphizer can + /// re-apply the needed bindings. + pub(super) fn type_check_operator_method( + &mut self, + expr_id: ExprId, + trait_method_id: TraitMethodId, + object_type: &Type, + span: Span, + ) { + let the_trait = self.interner.get_trait(trait_method_id.trait_id); + + let method = &the_trait.methods[trait_method_id.method_index]; + let (method_type, mut bindings) = method.typ.clone().instantiate(&self.interner); + + match method_type { + Type::Function(args, _, _) => { + // We can cheat a bit and match against only the object type here since no operator + // overload uses other generic parameters or return types aside from the object type. + let expected_object_type = &args[0]; + self.unify(object_type, expected_object_type, || TypeCheckError::TypeMismatch { + expected_typ: expected_object_type.to_string(), + expr_typ: object_type.to_string(), + expr_span: span, + }); + } + other => { + unreachable!("Expected operator method to have a function type, but found {other}") + } + } + + // We must also remember to apply these substitutions to the object_type + // referenced by the selected trait impl, if one has yet to be selected. + let impl_kind = self.interner.get_selected_impl_for_expression(expr_id); + if let Some(TraitImplKind::Assumed { object_type, trait_generics }) = impl_kind { + let the_trait = self.interner.get_trait(trait_method_id.trait_id); + let object_type = object_type.substitute(&bindings); + bindings.insert( + the_trait.self_type_typevar_id, + (the_trait.self_type_typevar.clone(), object_type.clone()), + ); + self.interner.select_impl_for_expression( + expr_id, + TraitImplKind::Assumed { object_type, trait_generics }, + ); + } + + self.interner.store_instantiation_bindings(expr_id, bindings); + } + + pub(super) fn type_check_member_access( + &mut self, + mut access: HirMemberAccess, + expr_id: ExprId, + lhs_type: Type, + span: Span, + ) -> Type { + let access_lhs = &mut access.lhs; + + let dereference_lhs = |this: &mut Self, lhs_type, element| { + let old_lhs = *access_lhs; + *access_lhs = this.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: crate::ast::UnaryOp::Dereference { implicitly_added: true }, + rhs: old_lhs, + })); + this.interner.push_expr_type(old_lhs, lhs_type); + this.interner.push_expr_type(*access_lhs, element); + + let old_location = this.interner.id_location(old_lhs); + this.interner.push_expr_location(*access_lhs, span, old_location.file); + }; + + // If this access is just a field offset, we want to avoid dereferencing + let dereference_lhs = (!access.is_offset).then_some(dereference_lhs); + + match self.check_field_access(&lhs_type, &access.rhs.0.contents, span, dereference_lhs) { + Some((element_type, index)) => { + self.interner.set_field_index(expr_id, index); + // We must update `access` in case we added any dereferences to it + self.interner.replace_expr(&expr_id, HirExpression::MemberAccess(access)); + element_type + } + None => Type::Error, + } + } + + pub(super) fn lookup_method( + &mut self, + object_type: &Type, + method_name: &str, + span: Span, + ) -> Option { + match object_type.follow_bindings() { + Type::Struct(typ, _args) => { + let id = typ.borrow().id; + match self.interner.lookup_method(object_type, id, method_name, false) { + Some(method_id) => Some(HirMethodReference::FuncId(method_id)), + None => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + } + } + // TODO: We should allow method calls on `impl Trait`s eventually. + // For now it is fine since they are only allowed on return types. + Type::TraitAsType(..) => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + Type::NamedGeneric(_, _) => { + let func_meta = self.interner.function_meta( + &self.current_function.expect("unexpected method outside a function"), + ); + + for constraint in &func_meta.trait_constraints { + if *object_type == constraint.typ { + if let Some(the_trait) = self.interner.try_get_trait(constraint.trait_id) { + for (method_index, method) in the_trait.methods.iter().enumerate() { + if method.name.0.contents == method_name { + let trait_method = TraitMethodId { + trait_id: constraint.trait_id, + method_index, + }; + return Some(HirMethodReference::TraitMethodId( + trait_method, + constraint.trait_generics.clone(), + )); + } + } + } + } + } + + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + // Mutable references to another type should resolve to methods of their element type. + // This may be a struct or a primitive type. + Type::MutableReference(element) => self + .interner + .lookup_primitive_trait_method_mut(element.as_ref(), method_name) + .map(HirMethodReference::FuncId) + .or_else(|| self.lookup_method(&element, method_name, span)), + + // If we fail to resolve the object to a struct type, we have no way of type + // checking its arguments as we can't even resolve the name of the function + Type::Error => None, + + // The type variable must be unbound at this point since follow_bindings was called + Type::TypeVariable(_, TypeVariableKind::Normal) => { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + None + } + + other => match self.interner.lookup_primitive_method(&other, method_name) { + Some(method_id) => Some(HirMethodReference::FuncId(method_id)), + None => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + }, + } + } + + pub(super) fn type_check_call( + &mut self, + call: &HirCallExpression, + func_type: Type, + args: Vec<(Type, ExprId, Span)>, + span: Span, + ) -> Type { + // Need to setup these flags here as `self` is borrowed mutably to type check the rest of the call expression + // These flags are later used to type check calls to unconstrained functions from constrained functions + let func_mod = self.current_function.map(|func| self.interner.function_modifiers(&func)); + let is_current_func_constrained = + func_mod.map_or(true, |func_mod| !func_mod.is_unconstrained); + + let is_unconstrained_call = self.is_unconstrained_call(call.func); + self.check_if_deprecated(call.func); + + // Check that we are not passing a mutable reference from a constrained runtime to an unconstrained runtime + if is_current_func_constrained && is_unconstrained_call { + for (typ, _, _) in args.iter() { + if matches!(&typ.follow_bindings(), Type::MutableReference(_)) { + self.push_err(TypeCheckError::ConstrainedReferenceToUnconstrained { span }); + } + } + } + + let return_type = self.bind_function_type(func_type, args, span); + + // Check that we are not passing a slice from an unconstrained runtime to a constrained runtime + if is_current_func_constrained && is_unconstrained_call { + if return_type.contains_slice() { + self.push_err(TypeCheckError::UnconstrainedSliceReturnToConstrained { span }); + } else if matches!(&return_type.follow_bindings(), Type::MutableReference(_)) { + self.push_err(TypeCheckError::UnconstrainedReferenceToConstrained { span }); + } + }; + + return_type + } + + fn check_if_deprecated(&mut self, expr: ExprId) { + if let HirExpression::Ident(HirIdent { location, id, impl_kind: _ }) = + self.interner.expression(&expr) + { + if let Some(DefinitionKind::Function(func_id)) = + self.interner.try_definition(id).map(|def| &def.kind) + { + let attributes = self.interner.function_attributes(func_id); + if let Some(note) = attributes.get_deprecated_note() { + self.push_err(TypeCheckError::CallDeprecated { + name: self.interner.definition_name(id).to_string(), + note, + span: location.span, + }); + } + } + } + } + + fn is_unconstrained_call(&self, expr: ExprId) -> bool { + if let HirExpression::Ident(HirIdent { id, .. }) = self.interner.expression(&expr) { + if let Some(DefinitionKind::Function(func_id)) = + self.interner.try_definition(id).map(|def| &def.kind) + { + let modifiers = self.interner.function_modifiers(func_id); + return modifiers.is_unconstrained; + } + } + false + } + + /// Check if the given method type requires a mutable reference to the object type, and check + /// if the given object type is already a mutable reference. If not, add one. + /// This is used to automatically transform a method call: `foo.bar()` into a function + /// call: `bar(&mut foo)`. + /// + /// A notable corner case of this function is where it interacts with auto-deref of `.`. + /// If a field is being mutated e.g. `foo.bar.mutate_bar()` where `foo: &mut Foo`, the compiler + /// will insert a dereference before bar `(*foo).bar.mutate_bar()` which would cause us to + /// mutate a copy of bar rather than a reference to it. We must check for this corner case here + /// and remove the implicitly added dereference operator if we find one. + pub(super) fn try_add_mutable_reference_to_object( + &mut self, + function_type: &Type, + object_type: &mut Type, + object: &mut ExprId, + ) { + let expected_object_type = match function_type { + Type::Function(args, _, _) => args.first(), + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(args, _, _) => args.first(), + typ => unreachable!("Unexpected type for function: {typ}"), + }, + typ => unreachable!("Unexpected type for function: {typ}"), + }; + + if let Some(expected_object_type) = expected_object_type { + let actual_type = object_type.follow_bindings(); + + if matches!(expected_object_type.follow_bindings(), Type::MutableReference(_)) { + if !matches!(actual_type, Type::MutableReference(_)) { + if let Err(error) = verify_mutable_reference(&self.interner, *object) { + self.push_err(TypeCheckError::ResolverError(error)); + } + + let new_type = Type::MutableReference(Box::new(actual_type)); + *object_type = new_type.clone(); + + // First try to remove a dereference operator that may have been implicitly + // inserted by a field access expression `foo.bar` on a mutable reference `foo`. + let new_object = self.try_remove_implicit_dereference(*object); + + // If that didn't work, then wrap the whole expression in an `&mut` + *object = new_object.unwrap_or_else(|| { + let location = self.interner.id_location(*object); + + let new_object = + self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: UnaryOp::MutableReference, + rhs: *object, + })); + self.interner.push_expr_type(new_object, new_type); + self.interner.push_expr_location(new_object, location.span, location.file); + new_object + }); + } + // Otherwise if the object type is a mutable reference and the method is not, insert as + // many dereferences as needed. + } else if matches!(actual_type, Type::MutableReference(_)) { + let (new_object, new_type) = self.insert_auto_dereferences(*object, actual_type); + *object_type = new_type; + *object = new_object; + } + } + } +} diff --git a/compiler/noirc_frontend/src/hir/resolution/import.rs b/compiler/noirc_frontend/src/hir/resolution/import.rs index 8850331f683..343113836ed 100644 --- a/compiler/noirc_frontend/src/hir/resolution/import.rs +++ b/compiler/noirc_frontend/src/hir/resolution/import.rs @@ -2,11 +2,14 @@ use noirc_errors::{CustomDiagnostic, Span}; use thiserror::Error; use crate::graph::CrateId; +use crate::hir::def_collector::dc_crate::CompilationError; use std::collections::BTreeMap; use crate::ast::{Ident, ItemVisibility, Path, PathKind}; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId, ModuleId, PerNs}; +use super::errors::ResolverError; + #[derive(Debug, Clone)] pub struct ImportDirective { pub module_id: LocalModuleId, @@ -53,6 +56,12 @@ pub struct ResolvedImport { pub error: Option, } +impl From for CompilationError { + fn from(error: PathResolutionError) -> Self { + Self::ResolverError(ResolverError::PathResolutionError(error)) + } +} + impl<'a> From<&'a PathResolutionError> for CustomDiagnostic { fn from(error: &'a PathResolutionError) -> Self { match &error { diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 6c07957b27f..fc4cd4f591c 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -56,17 +56,17 @@ use crate::hir_def::{ use super::errors::{PubPosition, ResolverError}; use super::import::PathResolution; -const SELF_TYPE_NAME: &str = "Self"; +pub const SELF_TYPE_NAME: &str = "Self"; type Scope = GenericScope; type ScopeTree = GenericScopeTree; type ScopeForest = GenericScopeForest; pub struct LambdaContext { - captures: Vec, + pub captures: Vec, /// the index in the scope tree /// (sometimes being filled by ScopeTree's find method) - scope_index: usize, + pub scope_index: usize, } /// The primary jobs of the Resolver are to validate that every variable found refers to exactly 1 @@ -1346,7 +1346,7 @@ impl<'a> Resolver<'a> { range @ ForRange::Array(_) => { let for_stmt = range.into_for(for_loop.identifier, for_loop.block, for_loop.span); - self.resolve_stmt(for_stmt, for_loop.span) + self.resolve_stmt(for_stmt.kind, for_loop.span) } } } @@ -1362,7 +1362,7 @@ impl<'a> Resolver<'a> { StatementKind::Comptime(statement) => { let hir_statement = self.resolve_stmt(statement.kind, statement.span); let statement_id = self.interner.push_stmt(hir_statement); - self.interner.push_statement_location(statement_id, statement.span, self.file); + self.interner.push_stmt_location(statement_id, statement.span, self.file); HirStatement::Comptime(statement_id) } } @@ -1413,7 +1413,7 @@ impl<'a> Resolver<'a> { pub fn intern_stmt(&mut self, stmt: Statement) -> StmtId { let hir_stmt = self.resolve_stmt(stmt.kind, stmt.span); let id = self.interner.push_stmt(hir_stmt); - self.interner.push_statement_location(id, stmt.span, self.file); + self.interner.push_stmt_location(id, stmt.span, self.file); id } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 9b40c959981..48598109829 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -250,14 +250,14 @@ impl<'interner> TypeChecker<'interner> { } // TODO: update object_type here? - let function_call = method_call.into_function_call( + let (_, function_call) = method_call.into_function_call( &method_ref, object_type, location, self.interner, ); - self.interner.replace_expr(expr_id, function_call); + self.interner.replace_expr(expr_id, HirExpression::Call(function_call)); // Type check the new call now that it has been changed from a method call // to a function call. This way we avoid duplicating code. diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index 0f8131d6ebb..2e448858d9e 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -25,7 +25,7 @@ use crate::{ Type, TypeBindings, }; -use self::errors::Source; +pub use self::errors::Source; pub struct TypeChecker<'interner> { interner: &'interner mut NodeInterner, diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index bf7d9b7b4ba..8df6785e0eb 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -200,13 +200,15 @@ pub enum HirMethodReference { impl HirMethodCallExpression { /// Converts a method call into a function call + /// + /// Returns ((func_var_id, func_var), call_expr) pub fn into_function_call( mut self, method: &HirMethodReference, object_type: Type, location: Location, interner: &mut NodeInterner, - ) -> HirExpression { + ) -> ((ExprId, HirIdent), HirCallExpression) { let mut arguments = vec![self.object]; arguments.append(&mut self.arguments); @@ -224,10 +226,11 @@ impl HirMethodCallExpression { (id, ImplKind::TraitMethod(*method_id, constraint, false)) } }; - let func = HirExpression::Ident(HirIdent { location, id, impl_kind }); - let func = interner.push_expr(func); + let func_var = HirIdent { location, id, impl_kind }; + let func = interner.push_expr(HirExpression::Ident(func_var.clone())); interner.push_expr_location(func, location.span, location.file); - HirExpression::Call(HirCallExpression { func, arguments, location }) + let expr = HirCallExpression { func, arguments, location }; + ((func, func_var), expr) } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index f3b2a24c1f0..f31aeea0552 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1423,14 +1423,14 @@ impl Type { /// Retrieves the type of the given field name /// Panics if the type is not a struct or tuple. - pub fn get_field_type(&self, field_name: &str) -> Type { + pub fn get_field_type(&self, field_name: &str) -> Option { match self { - Type::Struct(def, args) => def.borrow().get_field(field_name, args).unwrap().0, + Type::Struct(def, args) => def.borrow().get_field(field_name, args).map(|(typ, _)| typ), Type::Tuple(fields) => { let mut fields = fields.iter().enumerate(); - fields.find(|(i, _)| i.to_string() == *field_name).unwrap().1.clone() + fields.find(|(i, _)| i.to_string() == *field_name).map(|(_, typ)| typ).cloned() } - other => panic!("Tried to iterate over the fields of '{other}', which has none"), + _ => None, } } diff --git a/compiler/noirc_frontend/src/lib.rs b/compiler/noirc_frontend/src/lib.rs index 958a18ac2fb..b05c635f436 100644 --- a/compiler/noirc_frontend/src/lib.rs +++ b/compiler/noirc_frontend/src/lib.rs @@ -12,6 +12,7 @@ pub mod ast; pub mod debug; +pub mod elaborator; pub mod graph; pub mod lexer; pub mod monomorphization; diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 88adc7a9414..faf89016f96 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -532,7 +532,7 @@ impl NodeInterner { self.id_to_type.insert(expr_id.into(), typ); } - /// Store the type for an interned expression + /// Store the type for a definition pub fn push_definition_type(&mut self, definition_id: DefinitionId, typ: Type) { self.definition_to_type.insert(definition_id, typ); } @@ -696,7 +696,7 @@ impl NodeInterner { let statement = self.push_stmt(HirStatement::Error); let span = name.span(); let id = self.push_global(name, local_id, statement, file, attributes, mutable); - self.push_statement_location(statement, span, file); + self.push_stmt_location(statement, span, file); id } @@ -942,7 +942,7 @@ impl NodeInterner { self.id_location(stmt_id) } - pub fn push_statement_location(&mut self, id: StmtId, span: Span, file: FileId) { + pub fn push_stmt_location(&mut self, id: StmtId, span: Span, file: FileId) { self.id_to_location.insert(id.into(), Location::new(span, file)); }