Skip to content

Commit

Permalink
feat: Implement trait dispatch in the comptime interpreter (#5376)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4925

## Summary\*

Implements trait dispatch in the interpreter - which includes operator
overloading.

## Additional Context

I've tried to share as much code as I can between the interpreter and
monomorphization.

Most of this PR is rather straightforward adapting code in the
monomorphizer to use in the interpreter. The main new bit of code is a
stack of `FunctionContext`s in the elaborator. I found at that when
having a comptime block in the middle of a function, even a known trait
would panic that it had no impl. This was because we previously delayed
solving impls to the very end of a function when types were known. I've
had to change this into a stack instead so that we can solve impls that
were done within a comptime block at the end of that comptime block,
just before interpreting so that they'd be defined. Similarly, the
`type_variables` list also needed to be placed here since defaulting
types can cause some trait constraints to succeed/fail without it. I
settled on `function_context` for the name of this stack but am not sure
if it fits the case where we have a stack of comptime contexts instead.
Similarly, `ComptimeContext` also didn't make sense for the more common
case of only having one of them for a function with no comptime blocks.

## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Jul 2, 2024
1 parent 3d86dc6 commit 8aa5b2e
Show file tree
Hide file tree
Showing 15 changed files with 419 additions and 219 deletions.
9 changes: 8 additions & 1 deletion compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ impl<'context> Elaborator<'context> {
trait_id: trait_id.trait_id,
trait_generics: Vec::new(),
};
self.trait_constraints.push((constraint, expr_id));
self.push_trait_constraint(constraint, expr_id);
self.type_check_operator_method(expr_id, trait_id, &lhs_type, span);
}
typ
Expand Down Expand Up @@ -663,7 +663,14 @@ impl<'context> Elaborator<'context> {
}

fn elaborate_comptime_block(&mut self, block: BlockExpression, span: Span) -> (ExprId, Type) {
// We have to push a new FunctionContext so that we can resolve any constraints
// in this comptime block early before the function as a whole finishes elaborating.
// Otherwise the interpreter below may find expressions for which the underlying trait
// call is not yet solved for.
self.function_context.push(Default::default());
let (block, _typ) = self.elaborate_block_expression(block);
self.check_and_pop_function_context();

let mut interpreter =
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id);
let value = interpreter.evaluate_block(block);
Expand Down
101 changes: 58 additions & 43 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,14 @@ pub struct Elaborator<'context> {

current_function: Option<FuncId>,

/// All type variables created in the current function.
/// This map is used to default any integer type variables at the end of
/// a function (before checking trait constraints) if a type wasn't already chosen.
type_variables: Vec<Type>,

/// Trait constraints are collected during type checking until they are
/// verified at the end of a function. This is because constraints arise
/// on each variable, but it is only until function calls when the types
/// needed for the trait constraint may become known.
trait_constraints: Vec<(TraitConstraint, ExprId)>,
/// This is a stack of function contexts. Most of the time, for each function we
/// expect this to be of length one, containing each type variable and trait constraint
/// used in the function. This is also pushed to when a `comptime {}` block is used within
/// the function. Since it can force us to resolve that block's trait constraints earlier
/// so that they are resolved when the interpreter is run before the enclosing function
/// is finished elaborating. When this happens, we need to resolve any type variables
/// that were made within this block as well so that we can solve these traits.
function_context: Vec<FunctionContext>,

/// The current module this elaborator is in.
/// Initially empty, it is set whenever a new top-level item is resolved.
Expand All @@ -166,6 +164,20 @@ pub struct Elaborator<'context> {
unresolved_globals: BTreeMap<GlobalId, UnresolvedGlobal>,
}

#[derive(Default)]
struct FunctionContext {
/// All type variables created in the current function.
/// This map is used to default any integer type variables at the end of
/// a function (before checking trait constraints) if a type wasn't already chosen.
type_variables: Vec<Type>,

/// Trait constraints are collected during type checking until they are
/// verified at the end of a function. This is because constraints arise
/// on each variable, but it is only until function calls when the types
/// needed for the trait constraint may become known.
trait_constraints: Vec<(TraitConstraint, ExprId)>,
}

impl<'context> Elaborator<'context> {
pub fn new(context: &'context mut Context, crate_id: CrateId) -> Self {
Self {
Expand All @@ -185,8 +197,7 @@ impl<'context> Elaborator<'context> {
resolving_ids: BTreeSet::new(),
trait_bounds: Vec::new(),
current_function: None,
type_variables: Vec::new(),
trait_constraints: Vec::new(),
function_context: vec![FunctionContext::default()],
current_trait_impl: None,
comptime_scopes: vec![HashMap::default()],
unresolved_globals: BTreeMap::new(),
Expand Down Expand Up @@ -326,6 +337,7 @@ impl<'context> Elaborator<'context> {
let func_meta = func_meta.clone();

self.trait_bounds = func_meta.trait_constraints.clone();
self.function_context.push(FunctionContext::default());

// Introduce all numeric generics into scope
for generic in &func_meta.all_generics {
Expand Down Expand Up @@ -367,34 +379,11 @@ impl<'context> Elaborator<'context> {
self.type_check_function_body(body_type, &func_meta, hir_func.as_expr());
}

// Default any type variables that still need defaulting.
// Default any type variables that still need defaulting and
// verify any remaining trait constraints arising from the function body.
// This is done before trait impl search since leaving them bindable can lead to errors
// when multiple impls are available. Instead we default first to choose the Field or u64 impl.
for typ in &self.type_variables {
if let Type::TypeVariable(variable, kind) = typ.follow_bindings() {
let msg = "TypeChecker should only track defaultable type vars";
variable.bind(kind.default_type().expect(msg));
}
}

// Verify any remaining trait constraints arising from the function body
for (mut constraint, expr_id) in std::mem::take(&mut self.trait_constraints) {
let span = self.interner.expr_span(&expr_id);

if matches!(&constraint.typ, Type::MutableReference(_)) {
let (_, dereferenced_typ) =
self.insert_auto_dereferences(expr_id, constraint.typ.clone());
constraint.typ = dereferenced_typ;
}

self.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}
self.check_and_pop_function_context();

// Now remove all the `where` clause constraints we added
for constraint in &func_meta.trait_constraints {
Expand All @@ -417,12 +406,42 @@ impl<'context> Elaborator<'context> {
meta.function_body = FunctionBody::Resolved;

self.trait_bounds.clear();
self.type_variables.clear();
self.interner.update_fn(id, hir_func);
self.current_function = old_function;
self.current_item = old_item;
}

/// Defaults all type variables used in this function context then solves
/// all still-unsolved trait constraints in this context.
fn check_and_pop_function_context(&mut self) {
let context = self.function_context.pop().expect("Imbalanced function_context pushes");

for typ in context.type_variables {
if let Type::TypeVariable(variable, kind) = typ.follow_bindings() {
let msg = "TypeChecker should only track defaultable type vars";
variable.bind(kind.default_type().expect(msg));
}
}

for (mut constraint, expr_id) in context.trait_constraints {
let span = self.interner.expr_span(&expr_id);

if matches!(&constraint.typ, Type::MutableReference(_)) {
let (_, dereferenced_typ) =
self.insert_auto_dereferences(expr_id, constraint.typ.clone());
constraint.typ = dereferenced_typ;
}

self.verify_trait_constraint(
&constraint.typ,
constraint.trait_id,
&constraint.trait_generics,
expr_id,
span,
);
}
}

/// This turns function parameters of the form:
/// `fn foo(x: impl Bar)`
///
Expand Down Expand Up @@ -1339,10 +1358,6 @@ impl<'context> Elaborator<'context> {
self.elaborate_comptime_global(global_id);
}

// Avoid defaulting the types of globals here since they may be used in any function.
// Otherwise we may prematurely default to a Field inside the next function if this
// global was unused there, even if it is consistently used as a u8 everywhere else.
self.type_variables.clear();
self.local_module = old_module;
self.file = old_file;
self.current_item = old_item;
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ impl<'context> Elaborator<'context> {
let expr = self.resolve_variable(variable);

let id = self.interner.push_expr(HirExpression::Ident(expr.clone(), generics.clone()));

self.interner.push_expr_location(id, span, self.file);
let typ = self.type_check_variable(expr, id, generics);
self.interner.push_expr_type(id, typ.clone());
Expand Down Expand Up @@ -524,7 +525,7 @@ impl<'context> Elaborator<'context> {

for mut constraint in function.trait_constraints.clone() {
constraint.apply_bindings(&bindings);
self.trait_constraints.push((constraint, expr_id));
self.push_trait_constraint(constraint, expr_id);
}
}
}
Expand All @@ -541,7 +542,7 @@ impl<'context> Elaborator<'context> {
// Currently only one impl can be selected per expr_id, so this
// constraint needs to be pushed after any other constraints so
// that monomorphization can resolve this trait method to the correct impl.
self.trait_constraints.push((constraint, expr_id));
self.push_trait_constraint(constraint, expr_id);
}
}

Expand Down
7 changes: 7 additions & 0 deletions compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,15 @@ impl<'context> Elaborator<'context> {
}

fn elaborate_comptime_statement(&mut self, statement: Statement) -> (HirStatement, Type) {
// We have to push a new FunctionContext so that we can resolve any constraints
// in this comptime block early before the function as a whole finishes elaborating.
// Otherwise the interpreter below may find expressions for which the underlying trait
// call is not yet solved for.
self.function_context.push(Default::default());
let span = statement.span;
let (hir_statement, _typ) = self.elaborate_statement(statement);
self.check_and_pop_function_context();

let mut interpreter =
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id);
let value = interpreter.evaluate_statement(hir_statement);
Expand Down
46 changes: 23 additions & 23 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
errors::ResolverError,
resolver::{verify_mutable_reference, SELF_TYPE_NAME, WILDCARD_TYPE},
},
type_check::{Source, TypeCheckError},
type_check::{NoMatchingImplFoundError, Source, TypeCheckError},
},
hir_def::{
expr::{
Expand Down Expand Up @@ -615,15 +615,15 @@ impl<'context> Elaborator<'context> {
/// in self.type_variables to default it later.
pub(super) fn polymorphic_integer_or_field(&mut self) -> Type {
let typ = Type::polymorphic_integer_or_field(self.interner);
self.type_variables.push(typ.clone());
self.push_type_variable(typ.clone());
typ
}

/// Return a fresh integer type variable and log it
/// in self.type_variables to default it later.
pub(super) fn polymorphic_integer(&mut self) -> Type {
let typ = Type::polymorphic_integer(self.interner);
self.type_variables.push(typ.clone());
self.push_type_variable(typ.clone());
typ
}

Expand Down Expand Up @@ -1410,26 +1410,10 @@ impl<'context> Elaborator<'context> {
Err(erroring_constraints) => {
if erroring_constraints.is_empty() {
self.push_err(TypeCheckError::TypeAnnotationsNeeded { span });
} else {
// Don't show any errors where try_get_trait returns None.
// This can happen if a trait is used that was never declared.
let constraints = erroring_constraints
.into_iter()
.map(|constraint| {
let r#trait = self.interner.try_get_trait(constraint.trait_id)?;
let mut name = r#trait.name.to_string();
if !constraint.trait_generics.is_empty() {
let generics =
vecmap(&constraint.trait_generics, ToString::to_string);
name += &format!("<{}>", generics.join(", "));
}
Some((constraint.typ, name))
})
.collect::<Option<Vec<_>>>();

if let Some(constraints) = constraints {
self.push_err(TypeCheckError::NoMatchingImplFound { constraints, span });
}
} else if let Some(error) =
NoMatchingImplFoundError::new(self.interner, erroring_constraints, span)
{
self.push_err(TypeCheckError::NoMatchingImplFound(error));
}
}
}
Expand Down Expand Up @@ -1557,4 +1541,20 @@ impl<'context> Elaborator<'context> {
}
}
}

/// Push a type variable into the current FunctionContext to be defaulted if needed
/// at the end of the earlier of either the current function or the current comptime scope.
fn push_type_variable(&mut self, typ: Type) {
let context = self.function_context.last_mut();
let context = context.expect("The function_context stack should always be non-empty");
context.type_variables.push(typ);
}

/// Push a trait constraint into the current FunctionContext to be solved if needed
/// at the end of the earlier of either the current function or the current comptime scope.
pub fn push_trait_constraint(&mut self, constraint: TraitConstraint, expr_id: ExprId) {
let context = self.function_context.last_mut();
let context = context.expect("The function_context stack should always be non-empty");
context.trait_constraints.push((constraint, expr_id));
}
}
16 changes: 15 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/errors.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::rc::Rc;

use crate::{
hir::def_collector::dc_crate::CompilationError, parser::ParserError, token::Tokens, Type,
hir::{def_collector::dc_crate::CompilationError, type_check::NoMatchingImplFoundError},
parser::ParserError,
token::Tokens,
Type,
};
use acvm::{acir::AcirField, FieldElement};
use fm::FileId;
Expand Down Expand Up @@ -44,6 +47,8 @@ pub enum InterpreterError {
FailedToParseMacro { error: ParserError, tokens: Rc<Tokens>, rule: &'static str, file: FileId },
UnsupportedTopLevelItemUnquote { item: String, location: Location },
NonComptimeFnCallInSameCrate { function: String, location: Location },
NoImpl { location: Location },
NoMatchingImplFound { error: NoMatchingImplFoundError, file: FileId },

Unimplemented { item: String, location: Location },

Expand Down Expand Up @@ -106,11 +111,15 @@ impl InterpreterError {
| InterpreterError::UnsupportedTopLevelItemUnquote { location, .. }
| InterpreterError::NonComptimeFnCallInSameCrate { location, .. }
| InterpreterError::Unimplemented { location, .. }
| InterpreterError::NoImpl { location, .. }
| InterpreterError::BreakNotInLoop { location, .. }
| InterpreterError::ContinueNotInLoop { location, .. } => *location,
InterpreterError::FailedToParseMacro { error, file, .. } => {
Location::new(error.span(), *file)
}
InterpreterError::NoMatchingImplFound { error, file } => {
Location::new(error.span, *file)
}
InterpreterError::Break | InterpreterError::Continue => {
panic!("Tried to get the location of Break/Continue error!")
}
Expand Down Expand Up @@ -324,6 +333,11 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic {
let msg = "There is no loop to continue!".into();
CustomDiagnostic::simple_error(msg, String::new(), location.span)
}
InterpreterError::NoImpl { location } => {
let msg = "No impl found due to prior type error".into();
CustomDiagnostic::simple_error(msg, String::new(), location.span)
}
InterpreterError::NoMatchingImplFound { error, .. } => error.into(),
InterpreterError::Break => unreachable!("Uncaught InterpreterError::Break"),
InterpreterError::Continue => unreachable!("Uncaught InterpreterError::Continue"),
}
Expand Down
Loading

0 comments on commit 8aa5b2e

Please sign in to comment.