From 3214d25ba45d72e2bdb82ebceea5598c2d4b28df Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 25 Oct 2023 11:53:06 -0500 Subject: [PATCH] feat: Allow a trait to be implemented multiple times for the same struct (#3292) --- .../src/hir/def_collector/dc_crate.rs | 92 ++++++++----------- .../src/hir/def_collector/dc_mod.rs | 15 ++- compiler/noirc_frontend/src/node_interner.rs | 10 +- compiler/noirc_frontend/src/tests.rs | 28 +++--- .../trait_generics/Nargo.toml | 7 ++ .../trait_generics/src/main.nr | 27 ++++++ 6 files changed, 108 insertions(+), 71 deletions(-) create mode 100644 tooling/nargo_cli/tests/compile_success_empty/trait_generics/Nargo.toml create mode 100644 tooling/nargo_cli/tests/compile_success_empty/trait_generics/src/main.nr diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index aa2bedb37b7..50618c1ca8b 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -1,7 +1,7 @@ use super::dc_mod::collect_defs; use super::errors::{DefCollectorErrorKind, DuplicateType}; use crate::graph::CrateId; -use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId, ModuleId}; +use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleData, ModuleDefId, ModuleId}; use crate::hir::resolution::errors::ResolverError; use crate::hir::resolution::import::PathResolutionError; use crate::hir::resolution::path_resolver::PathResolver; @@ -126,7 +126,7 @@ pub struct DefCollector { pub enum CompilationError { ParseError(ParserError), DefinitionError(DefCollectorErrorKind), - ResolveError(ResolverError), + ResolverError(ResolverError), TypeError(TypeCheckError), } @@ -135,7 +135,7 @@ impl From for CustomDiagnostic { match value { CompilationError::ParseError(error) => error.into(), CompilationError::DefinitionError(error) => error.into(), - CompilationError::ResolveError(error) => error.into(), + CompilationError::ResolverError(error) => error.into(), CompilationError::TypeError(error) => error.into(), } } @@ -155,7 +155,7 @@ impl From for CompilationError { impl From for CompilationError { fn from(value: ResolverError) -> Self { - CompilationError::ResolveError(value) + CompilationError::ResolverError(value) } } impl From for CompilationError { @@ -296,12 +296,6 @@ impl DefCollector { // globals will need to reference the struct type they're initialized to to ensure they are valid. resolved_globals.extend(resolve_globals(context, other_globals, crate_id)); - // Before we resolve any function symbols we must go through our impls and - // re-collect the methods within into their proper module. This cannot be - // done before resolution since we need to be able to resolve the type of the - // impl since that determines the module we should collect into. - errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls)); - // Bind trait impls to their trait. Collect trait functions, that have a // default implementation, which hasn't been overridden. errors.extend(collect_trait_impls( @@ -310,6 +304,15 @@ impl DefCollector { &mut def_collector.collected_traits_impls, )); + // Before we resolve any function symbols we must go through our impls and + // re-collect the methods within into their proper module. This cannot be + // done before resolution since we need to be able to resolve the type of the + // impl since that determines the module we should collect into. + // + // These are resolved after trait impls so that struct methods are chosen + // over trait methods if there are name conflicts. + errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls)); + // Lower each function in the crate. This is now possible since imports have been resolved let file_func_ids = resolve_free_functions( &mut context.def_interner, @@ -377,7 +380,6 @@ fn collect_impls( if let Some(struct_type) = get_struct_type(&typ) { let struct_type = struct_type.borrow(); - let type_module = struct_type.id.local_module_id(); // `impl`s are only allowed on types defined within the current crate if struct_type.id.krate() != crate_id { @@ -391,7 +393,7 @@ fn collect_impls( // Grab the module defined by the struct type. Note that impls are a case // where the module the methods are added to is not the same as the module // they are resolved in. - let module = &mut def_maps.get_mut(&crate_id).unwrap().modules[type_module.0]; + let module = get_module_mut(def_maps, struct_type.id.module_id()); for (_, method_id, method) in &unresolved.functions { // If this method was already declared, remove it from the module so it cannot @@ -413,6 +415,13 @@ fn collect_impls( errors } +fn get_module_mut( + def_maps: &mut BTreeMap, + module: ModuleId, +) -> &mut ModuleData { + &mut def_maps.get_mut(&module.krate).unwrap().modules[module.local_id.0] +} + fn collect_trait_impl_methods( interner: &mut NodeInterner, def_maps: &BTreeMap, @@ -494,25 +503,6 @@ fn collect_trait_impl_methods( errors } -fn add_method_to_struct_namespace( - current_def_map: &mut CrateDefMap, - struct_type: &Shared, - func_id: FuncId, - name_ident: &Ident, - trait_id: TraitId, -) -> Result<(), DefCollectorErrorKind> { - let struct_type = struct_type.borrow(); - let type_module = struct_type.id.local_module_id(); - let module = &mut current_def_map.modules[type_module.0]; - module.declare_trait_function(name_ident.clone(), func_id, trait_id).map_err( - |(first_def, second_def)| DefCollectorErrorKind::Duplicate { - typ: DuplicateType::TraitImplementation, - first_def, - second_def, - }, - ) -} - fn collect_trait_impl( context: &mut Context, crate_id: CrateId, @@ -535,28 +525,24 @@ fn collect_trait_impl( if let Some(trait_id) = trait_impl.trait_id { errors .extend(collect_trait_impl_methods(interner, def_maps, crate_id, trait_id, trait_impl)); - for (_, func_id, ast) in &trait_impl.methods.functions { - let file = def_maps[&crate_id].file_id(trait_impl.module_id); - let path_resolver = StandardPathResolver::new(module); - let mut resolver = Resolver::new(interner, &path_resolver, def_maps, file); - resolver.add_generics(&ast.def.generics); - let typ = resolver.resolve_type(unresolved_type.clone()); - - if let Some(struct_type) = get_struct_type(&typ) { - errors.extend(take_errors(trait_impl.file_id, resolver)); - let current_def_map = def_maps.get_mut(&struct_type.borrow().id.krate()).unwrap(); - match add_method_to_struct_namespace( - current_def_map, - struct_type, - *func_id, - ast.name_ident(), - trait_id, - ) { - Ok(()) => {} - Err(err) => { - errors.push((err.into(), trait_impl.file_id)); - } + let path_resolver = StandardPathResolver::new(module); + let file = def_maps[&crate_id].file_id(trait_impl.module_id); + let mut resolver = Resolver::new(interner, &path_resolver, def_maps, file); + let typ = resolver.resolve_type(unresolved_type); + errors.extend(take_errors(trait_impl.file_id, resolver)); + + if let Some(struct_type) = get_struct_type(&typ) { + let struct_type = struct_type.borrow(); + let module = get_module_mut(def_maps, struct_type.id.module_id()); + + for (_, method_id, method) in &trait_impl.methods.functions { + // If this method was already declared, remove it from the module so it cannot + // be accessed with the `TypeName::method` syntax. We'll check later whether the + // object types in each method overlap or not. If they do, we issue an error. + // If not, that is specialization which is allowed. + if module.declare_function(method.name_ident().clone(), *method_id).is_err() { + module.remove_function(method.name_ident()); } } } @@ -841,7 +827,7 @@ fn take_errors_filter_self_not_resolved( } fn take_errors(file_id: FileId, resolver: Resolver<'_>) -> Vec<(CompilationError, FileId)> { - resolver.take_errors().iter().cloned().map(|e| (e.into(), file_id)).collect() + vecmap(resolver.take_errors(), |e| (e.into(), file_id)) } /// Create the mappings from TypeId -> TraitType diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index ed48d7fbb51..4af910b6f84 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -7,7 +7,7 @@ use noirc_errors::Location; use crate::{ graph::CrateId, hir::def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait}, - node_interner::{TraitId, TypeAliasId}, + node_interner::{FunctionModifiers, TraitId, TypeAliasId}, parser::{SortedModule, SortedSubModule}, FunctionDefinition, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, TraitImplItem, TraitItem, TypeImpl, @@ -378,11 +378,22 @@ impl<'a> ModCollector<'a> { body, } => { let func_id = context.def_interner.push_empty_fn(); + let modifiers = FunctionModifiers { + name: name.to_string(), + visibility: crate::FunctionVisibility::Public, + // TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629 + attributes: crate::token::Attributes::empty(), + is_unconstrained: false, + contract_function_type: None, + is_internal: None, + }; + + context.def_interner.push_function_definition(func_id, modifiers, id.0); + match self.def_collector.def_map.modules[id.0.local_id.0] .declare_function(name.clone(), func_id) { Ok(()) => { - // TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629 if let Some(body) = body { let impl_method = NoirFunction::normal(FunctionDefinition::normal( diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 514a65d37a2..5febc3f4259 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -618,9 +618,10 @@ impl NodeInterner { #[cfg(test)] pub fn push_test_function_definition(&mut self, name: String) -> FuncId { let id = self.push_fn(HirFunction::empty()); - let modifiers = FunctionModifiers::new(); + let mut modifiers = FunctionModifiers::new(); + modifiers.name = name; let module = ModuleId::dummy_id(); - self.push_function_definition(name, id, modifiers, module); + self.push_function_definition(id, modifiers, module); id } @@ -631,7 +632,6 @@ impl NodeInterner { module: ModuleId, ) -> DefinitionId { use ContractFunctionType::*; - let name = function.name.0.contents.clone(); // We're filling in contract_function_type and is_internal now, but these will be verified // later during name resolution. @@ -643,16 +643,16 @@ impl NodeInterner { contract_function_type: Some(if function.is_open { Open } else { Secret }), is_internal: Some(function.is_internal), }; - self.push_function_definition(name, id, modifiers, module) + self.push_function_definition(id, modifiers, module) } pub fn push_function_definition( &mut self, - name: String, func: FuncId, modifiers: FunctionModifiers, module: ModuleId, ) -> DefinitionId { + let name = modifiers.name.clone(); self.function_modifiers.insert(func, modifiers); self.function_modules.insert(func, module); self.push_definition(name, false, DefinitionKind::Function(func)) diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 91fce8d8862..b7953d0797c 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -461,7 +461,7 @@ mod test { for (err, _file_id) in errors { match &err { - CompilationError::ResolveError(ResolverError::PathResolutionError( + CompilationError::ResolverError(ResolverError::PathResolutionError( PathResolutionError::Unresolved(ident), )) => { assert_eq!(ident, "NotAType"); @@ -533,12 +533,11 @@ mod test { } } - fn main() { - } + fn main() {} "; let errors = get_program_errors(src); assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors); for (err, _file_id) in errors { match &err { CompilationError::DefinitionError( @@ -546,6 +545,12 @@ mod test { ) => { assert_eq!(trait_path.as_string(), "Default"); } + CompilationError::ResolverError(ResolverError::Expected { + expected, got, .. + }) => { + assert_eq!(expected, "type"); + assert_eq!(got, "function"); + } _ => { panic!("No other errors are expected! Found = {:?}", err); } @@ -810,7 +815,7 @@ mod test { assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); // It should be regarding the unused variable match &errors[0].0 { - CompilationError::ResolveError(ResolverError::UnusedVariable { ident }) => { + CompilationError::ResolverError(ResolverError::UnusedVariable { ident }) => { assert_eq!(&ident.0.contents, "y"); } _ => unreachable!("we should only have an unused var error"), @@ -829,7 +834,7 @@ mod test { assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); // It should be regarding the unresolved var `z` (Maybe change to undeclared and special case) match &errors[0].0 { - CompilationError::ResolveError(ResolverError::VariableNotDeclared { + CompilationError::ResolverError(ResolverError::VariableNotDeclared { name, span: _, }) => assert_eq!(name, "z"), @@ -848,7 +853,7 @@ mod test { assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); for (compilation_error, _file_id) in errors { match compilation_error { - CompilationError::ResolveError(err) => { + CompilationError::ResolverError(err) => { match err { ResolverError::PathResolutionError(PathResolutionError::Unresolved( name, @@ -892,7 +897,7 @@ mod test { // `foo::bar` does not exist for (compilation_error, _file_id) in errors { match compilation_error { - CompilationError::ResolveError(err) => { + CompilationError::ResolverError(err) => { match err { ResolverError::UnusedVariable { ident } => { assert_eq!(&ident.0.contents, "z"); @@ -1069,12 +1074,13 @@ mod test { for (err, _file_id) in errors { match &err { - CompilationError::ResolveError(ResolverError::VariableNotDeclared { - name, .. + CompilationError::ResolverError(ResolverError::VariableNotDeclared { + name, + .. }) => { assert_eq!(name, "i"); } - CompilationError::ResolveError(ResolverError::NumericConstantInFormatString { + CompilationError::ResolverError(ResolverError::NumericConstantInFormatString { name, .. }) => { diff --git a/tooling/nargo_cli/tests/compile_success_empty/trait_generics/Nargo.toml b/tooling/nargo_cli/tests/compile_success_empty/trait_generics/Nargo.toml new file mode 100644 index 00000000000..9da56eebf35 --- /dev/null +++ b/tooling/nargo_cli/tests/compile_success_empty/trait_generics/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "trait_generics" +type = "bin" +authors = [""] +compiler_version = "0.10.5" + +[dependencies] diff --git a/tooling/nargo_cli/tests/compile_success_empty/trait_generics/src/main.nr b/tooling/nargo_cli/tests/compile_success_empty/trait_generics/src/main.nr new file mode 100644 index 00000000000..c44366c006e --- /dev/null +++ b/tooling/nargo_cli/tests/compile_success_empty/trait_generics/src/main.nr @@ -0,0 +1,27 @@ + +struct Empty {} + +trait Foo { + fn foo(self) -> u32; +} + +impl Foo for Empty { + fn foo(_self: Self) -> u32 { 32 } +} + +impl Foo for Empty { + fn foo(_self: Self) -> u32 { 64 } +} + +fn main() { + let x: Empty = Empty {}; + let y: Empty = Empty {}; + let z = Empty {}; + + assert(x.foo() == 32); + assert(y.foo() == 64); + + // Types matching multiple impls will currently choose + // the first matching one instead of erroring + assert(z.foo() == 32); +}