Skip to content

Commit

Permalink
feat: Allow type aliases to reference other aliases (#4353)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves <!-- Link to GitHub Issue -->

## Summary\*

Allows type aliases to reference other aliases, using the dependency
graph to error if there are ever cycles.

To prevent infinite recursion in the type checker, aliases now have
their own `Type::Alias` node which is used
to set the inner aliased type to `Type::Error` in the case of a cycle to
break the cycle.

## Additional Context

Example error:

```
error: Dependency cycle found
  ┌─ /home/user/Code/Noir/noir/short/src/main.nr:2:1
  │
2 │ type B = A;
  │ ---------- 'B' recursively depends on itself: B -> A -> B
  │
```

## Documentation\*

Check one:
- [ ] No documentation needed.
- [x] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: kevaundray <[email protected]>
  • Loading branch information
jfecher and kevaundray authored Feb 13, 2024
1 parent 39af6cc commit c44ef14
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 55 deletions.
44 changes: 30 additions & 14 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::hir::def_map::{LocalModuleId, ModuleDefId, TryFromModuleDefId, MAIN_F
use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern};
use crate::node_interner::{
DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, NodeInterner, StmtId,
StructId, TraitId, TraitImplId, TraitMethodId,
StructId, TraitId, TraitImplId, TraitMethodId, TypeAliasId,
};
use crate::{
hir::{def_map::CrateDefMap, resolution::path_resolver::PathResolver},
Expand All @@ -39,9 +39,9 @@ use crate::{
use crate::{
ArrayLiteral, ContractFunctionType, Distinctness, ForRange, FunctionDefinition,
FunctionReturnType, FunctionVisibility, Generics, LValue, NoirStruct, NoirTypeAlias, Param,
Path, PathKind, Pattern, Shared, StructType, Type, TypeAliasType, TypeVariable,
TypeVariableKind, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType,
UnresolvedTypeData, UnresolvedTypeExpression, Visibility, ERROR_IDENT,
Path, PathKind, Pattern, Shared, StructType, Type, TypeAlias, TypeVariable, TypeVariableKind,
UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData,
UnresolvedTypeExpression, Visibility, ERROR_IDENT,
};
use fm::FileId;
use iter_extended::vecmap;
Expand Down Expand Up @@ -573,16 +573,19 @@ impl<'a> Resolver<'a> {
let span = path.span();
let mut args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables));

if let Some(type_alias_type) = self.lookup_type_alias(path.clone()) {
let expected_generic_count = type_alias_type.generics.len();
let type_alias_string = type_alias_type.to_string();
let id = type_alias_type.id;
if let Some(type_alias) = self.lookup_type_alias(path.clone()) {
let type_alias = type_alias.borrow();
let expected_generic_count = type_alias.generics.len();
let type_alias_string = type_alias.to_string();
let id = type_alias.id;

self.verify_generics_count(expected_generic_count, &mut args, span, || {
type_alias_string
});

let result = self.interner.get_type_alias(id).get_type(&args);
if let Some(item) = self.current_item {
self.interner.add_type_alias_dependency(item, id);
}

// Collecting Type Alias references [Location]s to be used by LSP in order
// to resolve the definition of the type alias
Expand All @@ -593,9 +596,8 @@ impl<'a> Resolver<'a> {
// equal to another type alias. Fixing this fully requires an analysis to create a DFG
// of definition ordering, but for now we have an explicit check here so that we at
// least issue an error that the type was not found instead of silently passing.
if result != Type::Error {
return result;
}
let alias = self.interner.get_type_alias(id);
return Type::Alias(alias, args);
}

match self.lookup_struct_or_error(path) {
Expand Down Expand Up @@ -752,12 +754,15 @@ impl<'a> Resolver<'a> {
resolved_type
}

pub fn resolve_type_aliases(
pub fn resolve_type_alias(
mut self,
unresolved: NoirTypeAlias,
alias_id: TypeAliasId,
) -> (Type, Generics, Vec<ResolverError>) {
let generics = self.add_generics(&unresolved.generics);
self.resolve_local_globals();

self.current_item = Some(DependencyId::Alias(alias_id));
let typ = self.resolve_type(unresolved.typ);

(typ, generics, self.errors)
Expand Down Expand Up @@ -1120,6 +1125,17 @@ impl<'a> Resolver<'a> {
}
}
}
Type::Alias(alias, generics) => {
for (i, generic) in generics.iter().enumerate() {
if let Type::NamedGeneric(type_variable, name) = generic {
if alias.borrow().generic_is_numeric(i) {
found.insert(name.to_string(), type_variable.clone());
}
} else {
Self::find_numeric_generics_in_type(generic, found);
}
}
}
Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found),
Type::String(length) => {
if let Type::NamedGeneric(type_variable, name) = length.as_ref() {
Expand Down Expand Up @@ -1791,7 +1807,7 @@ impl<'a> Resolver<'a> {
}
}

fn lookup_type_alias(&mut self, path: Path) -> Option<&TypeAliasType> {
fn lookup_type_alias(&mut self, path: Path) -> Option<Shared<TypeAlias>> {
self.lookup(path).ok().map(|id| self.interner.get_type_alias(id))
}

Expand Down
6 changes: 3 additions & 3 deletions compiler/noirc_frontend/src/hir/resolution/type_aliases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ pub(crate) fn resolve_type_aliases(
crate_id: CrateId,
) -> Vec<(CompilationError, FileId)> {
let mut errors: Vec<(CompilationError, FileId)> = vec![];
for (type_id, unresolved_typ) in type_aliases {
for (alias_id, unresolved_typ) in type_aliases {
let path_resolver = StandardPathResolver::new(ModuleId {
local_id: unresolved_typ.module_id,
krate: crate_id,
});
let file = unresolved_typ.file_id;
let (typ, generics, resolver_errors) =
Resolver::new(&mut context.def_interner, &path_resolver, &context.def_maps, file)
.resolve_type_aliases(unresolved_typ.type_alias_def);
.resolve_type_alias(unresolved_typ.type_alias_def, alias_id);
errors.extend(resolver_errors.iter().cloned().map(|e| (e.into(), file)));
context.def_interner.set_type_alias(type_id, typ, generics);
context.def_interner.set_type_alias(alias_id, typ, generics);
}
errors
}
8 changes: 8 additions & 0 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,10 @@ impl<'interner> TypeChecker<'interner> {
})
}
}
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.comparator_operand_type_rules(&alias, other, op, span)
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if sign_x != sign_y {
return Err(TypeCheckError::IntegerSignedness {
Expand Down Expand Up @@ -1141,6 +1145,10 @@ impl<'interner> TypeChecker<'interner> {
})
}
}
(Alias(alias, args), other) | (other, Alias(alias, args)) => {
let alias = alias.borrow().get_type(args);
self.infix_operand_type_rules(&alias, op, other, span)
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if sign_x != sign_y {
return Err(TypeCheckError::IntegerSignedness {
Expand Down
6 changes: 3 additions & 3 deletions compiler/noirc_frontend/src/hir/type_check/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl<'interner> TypeChecker<'interner> {
match pattern {
HirPattern::Identifier(ident) => self.interner.push_definition_type(ident.id, typ),
HirPattern::Mutable(pattern, _) => self.bind_pattern(pattern, typ),
HirPattern::Tuple(fields, location) => match typ {
HirPattern::Tuple(fields, location) => match typ.follow_bindings() {
Type::Tuple(field_types) if field_types.len() == fields.len() => {
for (field, field_type) in fields.iter().zip(field_types) {
self.bind_pattern(field, field_type);
Expand All @@ -120,12 +120,12 @@ impl<'interner> TypeChecker<'interner> {
source: Source::Assignment,
});

if let Type::Struct(struct_type, generics) = struct_type {
if let Type::Struct(struct_type, generics) = struct_type.follow_bindings() {
let struct_type = struct_type.borrow();

for (field_name, field_pattern) in fields {
if let Some((type_field, _)) =
struct_type.get_field(&field_name.0.contents, generics)
struct_type.get_field(&field_name.0.contents, &generics)
{
self.bind_pattern(field_pattern, type_field);
}
Expand Down
Loading

0 comments on commit c44ef14

Please sign in to comment.