From 0b5abf9e77618e9150b0dd322b3adbbabf93b2cf Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 30 Aug 2024 07:17:55 -0300 Subject: [PATCH 1/8] Introduce the Visitor pattern and use it in SignatureHelp --- compiler/noirc_frontend/src/ast/mod.rs | 3 + compiler/noirc_frontend/src/ast/visitor.rs | 801 ++++++++++++++++++ tooling/lsp/src/requests/signature_help.rs | 181 ++-- .../src/requests/signature_help/traversal.rs | 301 ------- 4 files changed, 895 insertions(+), 391 deletions(-) create mode 100644 compiler/noirc_frontend/src/ast/visitor.rs delete mode 100644 tooling/lsp/src/requests/signature_help/traversal.rs diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index b10e58aac0c..e63222bfc87 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -10,6 +10,9 @@ mod statement; mod structure; mod traits; mod type_alias; +mod visitor; + +pub use visitor::Visitor; pub use expression::*; pub use function::*; diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs new file mode 100644 index 00000000000..1186bef2bfb --- /dev/null +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -0,0 +1,801 @@ +use noirc_errors::Span; + +use crate::{ + ast::{ + ArrayLiteral, AsTraitPath, AssignStatement, BlockExpression, CallExpression, + CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind, + ForLoopStatement, ForRange, Ident, IfExpression, IndexExpression, InfixExpression, LValue, + Lambda, LetStatement, Literal, MemberAccessExpression, MethodCallExpression, + ModuleDeclaration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, + PrefixExpression, Statement, StatementKind, TraitImplItem, TraitItem, TypeImpl, UseTree, + UseTreeKind, + }, + parser::{Item, ItemKind, ParsedSubModule}, + ParsedModule, +}; + +/// Implements the [Visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for Noir's AST. +/// +/// In this implementation, methods must return a bool: +/// - true means children must be visited +/// - false means children must not be visited, either because the visitor implementation +/// will visit children of interest manually, or because no children are of interest +pub trait Visitor { + fn visit_parsed_module(&mut self, _: &ParsedModule) -> bool { + true + } + + fn visit_item(&mut self, _: &Item) -> bool { + true + } + + fn visit_parsed_submodule(&mut self, _: &ParsedSubModule, _: Span) -> bool { + true + } + + fn visit_noir_function(&mut self, _: &NoirFunction, _: Option) -> bool { + true + } + + fn visit_noir_trait_impl(&mut self, _: &NoirTraitImpl, _: Span) -> bool { + true + } + + fn visit_type_impl(&mut self, _: &TypeImpl, _: Span) -> bool { + true + } + + fn visit_trait_impl_item(&mut self, _: &TraitImplItem) -> bool { + true + } + + fn visit_noir_trait(&mut self, _: &NoirTrait, _: Span) -> bool { + true + } + + fn visit_trait_item(&mut self, _: &TraitItem) -> bool { + true + } + + fn visit_use_tree(&mut self, _: &UseTree) -> bool { + true + } + + fn visit_use_tree_path(&mut self, _: &UseTree, _ident: &Ident, _alias: &Option) {} + + fn visit_use_tree_list(&mut self, _: &UseTree, _: &[UseTree]) -> bool { + true + } + + fn visit_noir_struct(&mut self, _: &NoirStruct, _: Span) {} + + fn visit_noir_type_alias(&mut self, _: &NoirTypeAlias, _: Span) {} + + fn visit_module_declaration(&mut self, _: &ModuleDeclaration, _: Span) {} + + fn visit_expression(&mut self, _: &Expression) -> bool { + true + } + + fn visit_literal(&mut self, _: &Literal, _: Span) -> bool { + true + } + + fn visit_block_expression(&mut self, _: &BlockExpression, _: Option) -> bool { + true + } + + fn visit_prefix_expression(&mut self, _: &PrefixExpression, _: Span) -> bool { + true + } + + fn visit_index_expression(&mut self, _: &IndexExpression, _: Span) -> bool { + true + } + + fn visit_call_expression(&mut self, _: &CallExpression, _: Span) -> bool { + true + } + + fn visit_method_call_expression(&mut self, _: &MethodCallExpression, _: Span) -> bool { + true + } + + fn visit_constructor_expression(&mut self, _: &ConstructorExpression, _: Span) -> bool { + true + } + + fn visit_member_access_expression(&mut self, _: &MemberAccessExpression, _: Span) -> bool { + true + } + + fn visit_cast_expression(&mut self, _: &CastExpression, _: Span) -> bool { + true + } + + fn visit_infix_expression(&mut self, _: &InfixExpression, _: Span) -> bool { + true + } + + fn visit_if_expression(&mut self, _: &IfExpression, _: Span) -> bool { + true + } + + fn visit_tuple(&mut self, _: &[Expression], _: Span) -> bool { + true + } + + fn visit_parenthesized(&mut self, _: &Expression, _: Span) -> bool { + true + } + + fn visit_unquote(&mut self, _: &Expression, _: Span) -> bool { + true + } + + fn visit_comptime_expression(&mut self, _: &BlockExpression, _: Span) -> bool { + true + } + + fn visit_unsafe(&mut self, _: &BlockExpression, _: Span) -> bool { + true + } + + fn visit_variable(&mut self, _: &Path, _: Span) {} + + fn visit_lambda(&mut self, _: &Lambda, _: Span) -> bool { + true + } + + fn visit_array_literal(&mut self, _: &ArrayLiteral) -> bool { + true + } + + fn visit_statement(&mut self, _: &Statement) -> bool { + true + } + + fn visit_global(&mut self, _: &LetStatement, _: Span) -> bool { + true + } + + fn visit_let_statement(&mut self, _: &LetStatement) -> bool { + true + } + + fn visit_constrain_statement(&mut self, _: &ConstrainStatement) -> bool { + true + } + + fn visit_assign_statement(&mut self, _: &AssignStatement) -> bool { + true + } + + fn visit_for_loop_statement(&mut self, _: &ForLoopStatement) -> bool { + true + } + + fn visit_comptime_statement(&mut self, _: &Statement) -> bool { + true + } + + fn visit_lvalue(&mut self, _: &LValue) -> bool { + true + } + + fn visit_lvalue_ident(&mut self, _: &Ident) {} + + fn visit_for_range(&mut self, _: &ForRange) -> bool { + true + } + + fn visit_as_trait_path(&mut self, _: &AsTraitPath, _: Span) {} +} + +impl ParsedModule { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_parsed_module(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + for item in &self.items { + item.accept(visitor); + } + } +} + +impl Item { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_item(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.kind { + ItemKind::Submodules(parsed_sub_module) => { + parsed_sub_module.accept(self.span, visitor); + } + ItemKind::Function(noir_function) => noir_function.accept(Some(self.span), visitor), + ItemKind::TraitImpl(noir_trait_impl) => { + noir_trait_impl.accept(self.span, visitor); + } + ItemKind::Impl(type_impl) => type_impl.accept(self.span, visitor), + ItemKind::Global(let_statement) => { + if visitor.visit_global(let_statement, self.span) { + let_statement.accept(visitor) + } + } + ItemKind::Trait(noir_trait) => noir_trait.accept(self.span, visitor), + ItemKind::Import(use_tree) => use_tree.accept(visitor), + ItemKind::TypeAlias(noir_type_alias) => noir_type_alias.accept(self.span, visitor), + ItemKind::Struct(noir_struct) => noir_struct.accept(self.span, visitor), + ItemKind::ModuleDecl(module_declaration) => { + module_declaration.accept(self.span, visitor) + } + } + } +} + +impl ParsedSubModule { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_parsed_submodule(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.contents.accept(visitor); + } +} + +impl NoirFunction { + pub fn accept(&self, span: Option, visitor: &mut impl Visitor) { + if visitor.visit_noir_function(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.def.body.accept(None, visitor); + } +} + +impl NoirTraitImpl { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_noir_trait_impl(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + for item in &self.items { + item.accept(visitor); + } + } +} + +impl TraitImplItem { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_trait_impl_item(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + TraitImplItem::Function(noir_function) => noir_function.accept(None, visitor), + TraitImplItem::Constant(..) => (), + TraitImplItem::Type { .. } => (), + } + } +} + +impl TypeImpl { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_type_impl(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + for (method, span) in &self.methods { + method.accept(Some(*span), visitor); + } + } +} + +impl NoirTrait { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_noir_trait(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + for item in &self.items { + item.accept(visitor); + } + } +} + +impl TraitItem { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_trait_item(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + TraitItem::Function { + name: _, + generics: _, + parameters: _, + return_type: _, + where_clause: _, + body, + } => { + if let Some(body) = body { + body.accept(None, visitor); + } + } + TraitItem::Constant { name: _, typ: _, default_value } => { + if let Some(default_value) = default_value { + default_value.accept(visitor); + } + } + TraitItem::Type { name: _ } => (), + } + } +} + +impl UseTree { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_use_tree(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.kind { + UseTreeKind::Path(ident, alias) => visitor.visit_use_tree_path(self, ident, alias), + UseTreeKind::List(use_trees) => { + if visitor.visit_use_tree_list(self, use_trees) { + for use_tree in use_trees { + use_tree.accept(visitor); + } + } + } + } + } +} + +impl NoirStruct { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + visitor.visit_noir_struct(self, span); + } +} + +impl NoirTypeAlias { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + visitor.visit_noir_type_alias(self, span); + } +} + +impl ModuleDeclaration { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + visitor.visit_module_declaration(self, span); + } +} + +impl Expression { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_expression(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.kind { + ExpressionKind::Literal(literal) => literal.accept(self.span, visitor), + ExpressionKind::Block(block_expression) => { + block_expression.accept(Some(self.span), visitor); + } + ExpressionKind::Prefix(prefix_expression) => { + prefix_expression.accept(self.span, visitor); + } + ExpressionKind::Index(index_expression) => { + index_expression.accept(self.span, visitor); + } + ExpressionKind::Call(call_expression) => { + call_expression.accept(self.span, visitor); + } + ExpressionKind::MethodCall(method_call_expression) => { + method_call_expression.accept(self.span, visitor); + } + ExpressionKind::Constructor(constructor_expression) => { + constructor_expression.accept(self.span, visitor); + } + ExpressionKind::MemberAccess(member_access_expression) => { + member_access_expression.accept(self.span, visitor); + } + ExpressionKind::Cast(cast_expression) => { + cast_expression.accept(self.span, visitor); + } + ExpressionKind::Infix(infix_expression) => { + infix_expression.accept(self.span, visitor); + } + ExpressionKind::If(if_expression) => { + if_expression.accept(self.span, visitor); + } + ExpressionKind::Tuple(expressions) => { + if visitor.visit_tuple(expressions, self.span) { + visit_expressions(expressions, visitor); + } + } + ExpressionKind::Lambda(lambda) => lambda.accept(self.span, visitor), + ExpressionKind::Parenthesized(expression) => { + if visitor.visit_parenthesized(expression, self.span) { + expression.accept(visitor); + } + } + ExpressionKind::Unquote(expression) => { + if visitor.visit_unquote(expression, self.span) { + expression.accept(visitor); + } + } + ExpressionKind::Comptime(block_expression, _) => { + if visitor.visit_comptime_expression(block_expression, self.span) { + block_expression.accept(None, visitor); + } + } + ExpressionKind::Unsafe(block_expression, _) => { + if visitor.visit_unsafe(block_expression, self.span) { + block_expression.accept(None, visitor); + } + } + ExpressionKind::Variable(path) => { + visitor.visit_variable(path, self.span); + } + ExpressionKind::AsTraitPath(as_trait_path) => { + as_trait_path.accept(self.span, visitor); + } + ExpressionKind::Quote(_) + | ExpressionKind::Resolved(_) + | ExpressionKind::Interned(_) + | ExpressionKind::Error => (), + } + } +} + +impl Literal { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_literal(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + Literal::Array(array_literal) | Literal::Slice(array_literal) => { + array_literal.accept(visitor); + } + Literal::Bool(_) + | Literal::Integer(_, _) + | Literal::Str(_) + | Literal::RawStr(_, _) + | Literal::FmtStr(_) + | Literal::Unit => (), + } + } +} + +impl BlockExpression { + pub fn accept(&self, span: Option, visitor: &mut impl Visitor) { + if visitor.visit_block_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + for statement in &self.statements { + statement.accept(visitor); + } + } +} + +impl PrefixExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_prefix_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.rhs.accept(visitor); + } +} + +impl IndexExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_index_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.collection.accept(visitor); + self.index.accept(visitor); + } +} + +impl CallExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_call_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.func.accept(visitor); + visit_expressions(&self.arguments, visitor); + } +} + +impl MethodCallExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_method_call_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.object.accept(visitor); + visit_expressions(&self.arguments, visitor); + } +} + +impl ConstructorExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_constructor_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + for (_field_name, expression) in &self.fields { + expression.accept(visitor); + } + } +} + +impl MemberAccessExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_member_access_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.lhs.accept(visitor); + } +} + +impl CastExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_cast_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.lhs.accept(visitor); + } +} + +impl InfixExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_infix_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.lhs.accept(visitor); + self.rhs.accept(visitor); + } +} + +impl IfExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_if_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.condition.accept(visitor); + self.consequence.accept(visitor); + if let Some(alternative) = &self.alternative { + alternative.accept(visitor); + } + } +} + +impl Lambda { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_lambda(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.body.accept(visitor); + } +} + +impl ArrayLiteral { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_array_literal(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + ArrayLiteral::Standard(expressions) => visit_expressions(expressions, visitor), + ArrayLiteral::Repeated { repeated_element, length } => { + repeated_element.accept(visitor); + length.accept(visitor); + } + } + } +} + +impl Statement { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_statement(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.kind { + StatementKind::Let(let_statement) => { + let_statement.accept(visitor); + } + StatementKind::Constrain(constrain_statement) => { + constrain_statement.accept(visitor); + } + StatementKind::Expression(expression) => { + expression.accept(visitor); + } + StatementKind::Assign(assign_statement) => { + assign_statement.accept(visitor); + } + StatementKind::For(for_loop_statement) => { + for_loop_statement.accept(visitor); + } + StatementKind::Comptime(statement) => { + if visitor.visit_comptime_statement(statement) { + statement.accept(visitor); + } + } + StatementKind::Semi(expression) => { + expression.accept(visitor); + } + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), + } + } +} + +impl LetStatement { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_let_statement(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.expression.accept(visitor); + } +} + +impl ConstrainStatement { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_constrain_statement(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.0.accept(visitor); + + if let Some(exp) = &self.1 { + exp.accept(visitor); + } + } +} + +impl AssignStatement { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_assign_statement(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.lvalue.accept(visitor); + self.expression.accept(visitor); + } +} + +impl ForLoopStatement { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_for_loop_statement(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.range.accept(visitor); + self.block.accept(visitor); + } +} + +impl LValue { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_lvalue(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + LValue::Ident(ident) => visitor.visit_lvalue_ident(ident), + LValue::MemberAccess { object, field_name: _, span: _ } => object.accept(visitor), + LValue::Index { array, index, span: _ } => { + array.accept(visitor); + index.accept(visitor); + } + LValue::Dereference(lvalue, _) => lvalue.accept(visitor), + LValue::Interned(..) => (), + } + } +} + +impl ForRange { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_for_range(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + ForRange::Range(start, end) => { + start.accept(visitor); + end.accept(visitor); + } + ForRange::Array(expression) => expression.accept(visitor), + } + } +} + +impl AsTraitPath { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + visitor.visit_as_trait_path(self, span); + } +} + +fn visit_expressions(expressions: &[Expression], visitor: &mut impl Visitor) { + for expression in expressions { + expression.accept(visitor); + } +} diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index 25676b57381..23f57dcac4f 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -9,7 +9,7 @@ use noirc_errors::{Location, Span}; use noirc_frontend::{ ast::{ CallExpression, ConstrainKind, ConstrainStatement, Expression, ExpressionKind, - FunctionReturnType, MethodCallExpression, + FunctionReturnType, MethodCallExpression, Visitor, }, hir_def::{function::FuncMeta, stmt::HirPattern}, macros_api::NodeInterner, @@ -22,7 +22,6 @@ use crate::{utils, LspState}; use super::process_request; mod tests; -mod traversal; pub(crate) fn on_signature_help_request( state: &mut LspState, @@ -64,98 +63,11 @@ impl<'a> SignatureFinder<'a> { } fn find(&mut self, parsed_module: &ParsedModule) -> Option { - self.find_in_parsed_module(parsed_module); + parsed_module.accept(self); self.signature_help.clone() } - fn find_in_call_expression(&mut self, call_expression: &CallExpression, span: Span) { - self.find_in_expression(&call_expression.func); - self.find_in_expressions(&call_expression.arguments); - - let arguments_span = Span::from(call_expression.func.span.end() + 1..span.end() - 1); - let span = call_expression.func.span; - let name_span = Span::from(span.end() - 1..span.end()); - let has_self = false; - - self.try_compute_signature_help( - &call_expression.arguments, - arguments_span, - name_span, - has_self, - ); - } - - fn find_in_method_call_expression( - &mut self, - method_call_expression: &MethodCallExpression, - span: Span, - ) { - self.find_in_expression(&method_call_expression.object); - self.find_in_expressions(&method_call_expression.arguments); - - let arguments_span = - Span::from(method_call_expression.method_name.span().end() + 1..span.end() - 1); - let name_span = method_call_expression.method_name.span(); - let has_self = true; - - self.try_compute_signature_help( - &method_call_expression.arguments, - arguments_span, - name_span, - has_self, - ); - } - - pub(super) fn find_in_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) { - self.find_in_expression(&constrain_statement.0); - - if let Some(exp) = &constrain_statement.1 { - self.find_in_expression(exp); - } - - if self.signature_help.is_some() { - return; - } - - let arguments_span = if let Some(expr) = &constrain_statement.1 { - Span::from(constrain_statement.0.span.start()..expr.span.end()) - } else { - constrain_statement.0.span - }; - - if !self.includes_span(arguments_span) { - return; - } - - match constrain_statement.2 { - ConstrainKind::Assert => { - let mut arguments = vec![constrain_statement.0.clone()]; - if let Some(expr) = &constrain_statement.1 { - arguments.push(expr.clone()); - } - - let active_parameter = self.compute_active_parameter(&arguments); - let signature_information = self.assert_signature_information(active_parameter); - self.set_signature_help(signature_information); - } - ConstrainKind::AssertEq => { - if let ExpressionKind::Infix(infix) = &constrain_statement.0.kind { - let mut arguments = vec![infix.lhs.clone(), infix.rhs.clone()]; - if let Some(expr) = &constrain_statement.1 { - arguments.push(expr.clone()); - } - - let active_parameter = self.compute_active_parameter(&arguments); - let signature_information = - self.assert_eq_signature_information(active_parameter); - self.set_signature_help(signature_information); - } - } - ConstrainKind::Constrain => (), - } - } - fn try_compute_signature_help( &mut self, arguments: &[Expression], @@ -401,3 +313,92 @@ impl<'a> SignatureFinder<'a> { span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize } } + +impl<'a> Visitor for SignatureFinder<'a> { + fn visit_call_expression(&mut self, call_expression: &CallExpression, span: Span) -> bool { + call_expression.accept_children(self); + + let arguments_span = Span::from(call_expression.func.span.end() + 1..span.end() - 1); + let span = call_expression.func.span; + let name_span = Span::from(span.end() - 1..span.end()); + let has_self = false; + + self.try_compute_signature_help( + &call_expression.arguments, + arguments_span, + name_span, + has_self, + ); + + false + } + + fn visit_method_call_expression( + &mut self, + method_call_expression: &MethodCallExpression, + span: Span, + ) -> bool { + method_call_expression.accept_children(self); + + let arguments_span = + Span::from(method_call_expression.method_name.span().end() + 1..span.end() - 1); + let name_span = method_call_expression.method_name.span(); + let has_self = true; + + self.try_compute_signature_help( + &method_call_expression.arguments, + arguments_span, + name_span, + has_self, + ); + + false + } + + fn visit_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) -> bool { + constrain_statement.accept_children(self); + + if self.signature_help.is_some() { + return false; + } + + let arguments_span = if let Some(expr) = &constrain_statement.1 { + Span::from(constrain_statement.0.span.start()..expr.span.end()) + } else { + constrain_statement.0.span + }; + + if !self.includes_span(arguments_span) { + return false; + } + + match constrain_statement.2 { + ConstrainKind::Assert => { + let mut arguments = vec![constrain_statement.0.clone()]; + if let Some(expr) = &constrain_statement.1 { + arguments.push(expr.clone()); + } + + let active_parameter = self.compute_active_parameter(&arguments); + let signature_information = self.assert_signature_information(active_parameter); + self.set_signature_help(signature_information); + } + ConstrainKind::AssertEq => { + if let ExpressionKind::Infix(infix) = &constrain_statement.0.kind { + let mut arguments = vec![infix.lhs.clone(), infix.rhs.clone()]; + if let Some(expr) = &constrain_statement.1 { + arguments.push(expr.clone()); + } + + let active_parameter = self.compute_active_parameter(&arguments); + let signature_information = + self.assert_eq_signature_information(active_parameter); + self.set_signature_help(signature_information); + } + } + ConstrainKind::Constrain => (), + } + + false + } +} diff --git a/tooling/lsp/src/requests/signature_help/traversal.rs b/tooling/lsp/src/requests/signature_help/traversal.rs deleted file mode 100644 index e9b050fc965..00000000000 --- a/tooling/lsp/src/requests/signature_help/traversal.rs +++ /dev/null @@ -1,301 +0,0 @@ -/// This file includes the signature help logic that's just about -/// traversing the AST without any additional logic. -use super::SignatureFinder; - -use noirc_frontend::{ - ast::{ - ArrayLiteral, AssignStatement, BlockExpression, CastExpression, ConstructorExpression, - Expression, ExpressionKind, ForLoopStatement, ForRange, IfExpression, IndexExpression, - InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression, - NoirFunction, NoirTrait, NoirTraitImpl, Statement, StatementKind, TraitImplItem, TraitItem, - TypeImpl, - }, - parser::{Item, ItemKind}, - ParsedModule, -}; - -impl<'a> SignatureFinder<'a> { - pub(super) fn find_in_parsed_module(&mut self, parsed_module: &ParsedModule) { - for item in &parsed_module.items { - self.find_in_item(item); - } - } - - pub(super) fn find_in_item(&mut self, item: &Item) { - if !self.includes_span(item.span) { - return; - } - - match &item.kind { - ItemKind::Submodules(parsed_sub_module) => { - self.find_in_parsed_module(&parsed_sub_module.contents); - } - ItemKind::Function(noir_function) => self.find_in_noir_function(noir_function), - ItemKind::TraitImpl(noir_trait_impl) => self.find_in_noir_trait_impl(noir_trait_impl), - ItemKind::Impl(type_impl) => self.find_in_type_impl(type_impl), - ItemKind::Global(let_statement) => self.find_in_let_statement(let_statement), - ItemKind::Trait(noir_trait) => self.find_in_noir_trait(noir_trait), - ItemKind::Import(..) - | ItemKind::TypeAlias(_) - | ItemKind::Struct(_) - | ItemKind::ModuleDecl(_) => (), - } - } - - pub(super) fn find_in_noir_function(&mut self, noir_function: &NoirFunction) { - self.find_in_block_expression(&noir_function.def.body); - } - - pub(super) fn find_in_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl) { - for item in &noir_trait_impl.items { - self.find_in_trait_impl_item(item); - } - } - - pub(super) fn find_in_trait_impl_item(&mut self, item: &TraitImplItem) { - match item { - TraitImplItem::Function(noir_function) => self.find_in_noir_function(noir_function), - TraitImplItem::Constant(_, _, _) => (), - TraitImplItem::Type { .. } => (), - } - } - - pub(super) fn find_in_type_impl(&mut self, type_impl: &TypeImpl) { - for (method, span) in &type_impl.methods { - if self.includes_span(*span) { - self.find_in_noir_function(method); - } - } - } - - pub(super) fn find_in_noir_trait(&mut self, noir_trait: &NoirTrait) { - for item in &noir_trait.items { - self.find_in_trait_item(item); - } - } - - pub(super) fn find_in_trait_item(&mut self, trait_item: &TraitItem) { - match trait_item { - TraitItem::Function { body, .. } => { - if let Some(body) = body { - self.find_in_block_expression(body); - }; - } - TraitItem::Constant { default_value, .. } => { - if let Some(default_value) = default_value { - self.find_in_expression(default_value); - } - } - TraitItem::Type { .. } => (), - } - } - - pub(super) fn find_in_block_expression(&mut self, block_expression: &BlockExpression) { - for statement in &block_expression.statements { - if self.includes_span(statement.span) { - self.find_in_statement(statement); - } - } - } - - pub(super) fn find_in_statement(&mut self, statement: &Statement) { - if !self.includes_span(statement.span) { - return; - } - - match &statement.kind { - StatementKind::Let(let_statement) => { - self.find_in_let_statement(let_statement); - } - StatementKind::Constrain(constrain_statement) => { - self.find_in_constrain_statement(constrain_statement); - } - StatementKind::Expression(expression) => { - self.find_in_expression(expression); - } - StatementKind::Assign(assign_statement) => { - self.find_in_assign_statement(assign_statement); - } - StatementKind::For(for_loop_statement) => { - self.find_in_for_loop_statement(for_loop_statement); - } - StatementKind::Comptime(statement) => { - self.find_in_statement(statement); - } - StatementKind::Semi(expression) => { - self.find_in_expression(expression); - } - StatementKind::Break - | StatementKind::Continue - | StatementKind::Interned(_) - | StatementKind::Error => (), - } - } - - pub(super) fn find_in_let_statement(&mut self, let_statement: &LetStatement) { - self.find_in_expression(&let_statement.expression); - } - - pub(super) fn find_in_assign_statement(&mut self, assign_statement: &AssignStatement) { - self.find_in_lvalue(&assign_statement.lvalue); - self.find_in_expression(&assign_statement.expression); - } - - pub(super) fn find_in_for_loop_statement(&mut self, for_loop_statement: &ForLoopStatement) { - self.find_in_for_range(&for_loop_statement.range); - self.find_in_expression(&for_loop_statement.block); - } - - pub(super) fn find_in_lvalue(&mut self, lvalue: &LValue) { - match lvalue { - LValue::Ident(_) => (), - LValue::MemberAccess { object, field_name: _, span: _ } => self.find_in_lvalue(object), - LValue::Index { array, index, span: _ } => { - self.find_in_lvalue(array); - self.find_in_expression(index); - } - LValue::Dereference(lvalue, _) => self.find_in_lvalue(lvalue), - LValue::Interned(..) => (), - } - } - - pub(super) fn find_in_for_range(&mut self, for_range: &ForRange) { - match for_range { - ForRange::Range(start, end) => { - self.find_in_expression(start); - self.find_in_expression(end); - } - ForRange::Array(expression) => self.find_in_expression(expression), - } - } - - pub(super) fn find_in_expressions(&mut self, expressions: &[Expression]) { - for expression in expressions { - self.find_in_expression(expression); - } - } - - pub(super) fn find_in_expression(&mut self, expression: &Expression) { - match &expression.kind { - ExpressionKind::Literal(literal) => self.find_in_literal(literal), - ExpressionKind::Block(block_expression) => { - self.find_in_block_expression(block_expression); - } - ExpressionKind::Prefix(prefix_expression) => { - self.find_in_expression(&prefix_expression.rhs); - } - ExpressionKind::Index(index_expression) => { - self.find_in_index_expression(index_expression); - } - ExpressionKind::Call(call_expression) => { - self.find_in_call_expression(call_expression, expression.span); - } - ExpressionKind::MethodCall(method_call_expression) => { - self.find_in_method_call_expression(method_call_expression, expression.span); - } - ExpressionKind::Constructor(constructor_expression) => { - self.find_in_constructor_expression(constructor_expression); - } - ExpressionKind::MemberAccess(member_access_expression) => { - self.find_in_member_access_expression(member_access_expression); - } - ExpressionKind::Cast(cast_expression) => { - self.find_in_cast_expression(cast_expression); - } - ExpressionKind::Infix(infix_expression) => { - self.find_in_infix_expression(infix_expression); - } - ExpressionKind::If(if_expression) => { - self.find_in_if_expression(if_expression); - } - ExpressionKind::Tuple(expressions) => { - self.find_in_expressions(expressions); - } - ExpressionKind::Lambda(lambda) => self.find_in_lambda(lambda), - ExpressionKind::Parenthesized(expression) => { - self.find_in_expression(expression); - } - ExpressionKind::Unquote(expression) => { - self.find_in_expression(expression); - } - ExpressionKind::Comptime(block_expression, _) => { - self.find_in_block_expression(block_expression); - } - ExpressionKind::Unsafe(block_expression, _) => { - self.find_in_block_expression(block_expression); - } - ExpressionKind::Variable(_) - | ExpressionKind::AsTraitPath(_) - | ExpressionKind::Quote(_) - | ExpressionKind::Resolved(_) - | ExpressionKind::Interned(_) - | ExpressionKind::Error => (), - } - } - - pub(super) fn find_in_literal(&mut self, literal: &Literal) { - match literal { - Literal::Array(array_literal) => self.find_in_array_literal(array_literal), - Literal::Slice(array_literal) => self.find_in_array_literal(array_literal), - Literal::Bool(_) - | Literal::Integer(_, _) - | Literal::Str(_) - | Literal::RawStr(_, _) - | Literal::FmtStr(_) - | Literal::Unit => (), - } - } - - pub(super) fn find_in_array_literal(&mut self, array_literal: &ArrayLiteral) { - match array_literal { - ArrayLiteral::Standard(expressions) => self.find_in_expressions(expressions), - ArrayLiteral::Repeated { repeated_element, length } => { - self.find_in_expression(repeated_element); - self.find_in_expression(length); - } - } - } - - pub(super) fn find_in_index_expression(&mut self, index_expression: &IndexExpression) { - self.find_in_expression(&index_expression.collection); - self.find_in_expression(&index_expression.index); - } - - pub(super) fn find_in_constructor_expression( - &mut self, - constructor_expression: &ConstructorExpression, - ) { - for (_field_name, expression) in &constructor_expression.fields { - self.find_in_expression(expression); - } - } - - pub(super) fn find_in_member_access_expression( - &mut self, - member_access_expression: &MemberAccessExpression, - ) { - self.find_in_expression(&member_access_expression.lhs); - } - - pub(super) fn find_in_cast_expression(&mut self, cast_expression: &CastExpression) { - self.find_in_expression(&cast_expression.lhs); - } - - pub(super) fn find_in_infix_expression(&mut self, infix_expression: &InfixExpression) { - self.find_in_expression(&infix_expression.lhs); - self.find_in_expression(&infix_expression.rhs); - } - - pub(super) fn find_in_if_expression(&mut self, if_expression: &IfExpression) { - self.find_in_expression(&if_expression.condition); - self.find_in_expression(&if_expression.consequence); - - if let Some(alternative) = &if_expression.alternative { - self.find_in_expression(alternative); - } - } - - pub(super) fn find_in_lambda(&mut self, lambda: &Lambda) { - self.find_in_expression(&lambda.body); - } -} From d450731111f0bf1fdad7b3ddff697f46dc4ef846 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 30 Aug 2024 10:04:43 -0300 Subject: [PATCH 2/8] Use the visitor pattern for LSP completion --- compiler/noirc_frontend/src/ast/visitor.rs | 326 ++- tooling/lsp/src/requests/completion.rs | 1824 ++++++++--------- .../lsp/src/requests/completion/traversal.rs | 126 -- 3 files changed, 1163 insertions(+), 1113 deletions(-) delete mode 100644 tooling/lsp/src/requests/completion/traversal.rs diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs index 1186bef2bfb..2ebed378620 100644 --- a/compiler/noirc_frontend/src/ast/visitor.rs +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -14,6 +14,11 @@ use crate::{ ParsedModule, }; +use super::{ + FunctionReturnType, GenericTypeArgs, UnresolvedGenerics, UnresolvedTraitConstraint, + UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, +}; + /// Implements the [Visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for Noir's AST. /// /// In this implementation, methods must return a bool: @@ -57,6 +62,29 @@ pub trait Visitor { true } + fn visit_trait_item_function( + &mut self, + _name: &Ident, + _generics: &UnresolvedGenerics, + _parameters: &[(Ident, UnresolvedType)], + _return_type: &FunctionReturnType, + _where_clause: &[UnresolvedTraitConstraint], + _body: &Option, + ) -> bool { + true + } + + fn visit_trait_item_constant( + &mut self, + _name: &Ident, + _typ: &UnresolvedType, + _default_value: &Option, + ) -> bool { + true + } + + fn visit_trait_item_type(&mut self, _: &Ident) {} + fn visit_use_tree(&mut self, _: &UseTree) -> bool { true } @@ -67,9 +95,13 @@ pub trait Visitor { true } - fn visit_noir_struct(&mut self, _: &NoirStruct, _: Span) {} + fn visit_noir_struct(&mut self, _: &NoirStruct, _: Span) -> bool { + true + } - fn visit_noir_type_alias(&mut self, _: &NoirTypeAlias, _: Span) {} + fn visit_noir_type_alias(&mut self, _: &NoirTypeAlias, _: Span) -> bool { + true + } fn visit_module_declaration(&mut self, _: &ModuleDeclaration, _: Span) {} @@ -141,7 +173,9 @@ pub trait Visitor { true } - fn visit_variable(&mut self, _: &Path, _: Span) {} + fn visit_variable(&mut self, _: &Path, _: Span) -> bool { + true + } fn visit_lambda(&mut self, _: &Lambda, _: Span) -> bool { true @@ -155,6 +189,10 @@ pub trait Visitor { true } + fn visit_import(&mut self, _: &UseTree) -> bool { + true + } + fn visit_global(&mut self, _: &LetStatement, _: Span) -> bool { true } @@ -189,7 +227,71 @@ pub trait Visitor { true } - fn visit_as_trait_path(&mut self, _: &AsTraitPath, _: Span) {} + fn visit_as_trait_path(&mut self, _: &AsTraitPath, _: Span) -> bool { + true + } + + fn visit_unresolved_type(&mut self, _: &UnresolvedType) -> bool { + true + } + + fn visit_array_type( + &mut self, + _: &UnresolvedTypeExpression, + _: &UnresolvedType, + _: Span, + ) -> bool { + true + } + + fn visit_slice_type(&mut self, _: &UnresolvedType, _: Span) -> bool { + true + } + + fn visit_parenthesized_type(&mut self, _: &UnresolvedType, _: Span) -> bool { + true + } + + fn visit_named_type(&mut self, _: &Path, _: &GenericTypeArgs, _: Span) -> bool { + true + } + + fn visit_trait_as_type(&mut self, _: &Path, _: &GenericTypeArgs, _: Span) -> bool { + true + } + + fn visit_mutable_reference_type(&mut self, _: &UnresolvedType, _: Span) -> bool { + true + } + + fn visit_tuple_type(&mut self, _: &[UnresolvedType], _: Span) -> bool { + true + } + + fn visit_function_type( + &mut self, + _args: &[UnresolvedType], + _ret: &UnresolvedType, + _env: &UnresolvedType, + _unconstrained: bool, + _span: Span, + ) -> bool { + true + } + + fn visit_as_trait_path_type(&mut self, _: &AsTraitPath, _: Span) -> bool { + true + } + + fn visit_path(&mut self, _: &Path) {} + + fn visit_generic_type_args(&mut self, _: &GenericTypeArgs) -> bool { + true + } + + fn visit_function_return_type(&mut self, _: &FunctionReturnType) -> bool { + true + } } impl ParsedModule { @@ -229,7 +331,11 @@ impl Item { } } ItemKind::Trait(noir_trait) => noir_trait.accept(self.span, visitor), - ItemKind::Import(use_tree) => use_tree.accept(visitor), + ItemKind::Import(use_tree) => { + if visitor.visit_import(use_tree) { + use_tree.accept(visitor) + } + } ItemKind::TypeAlias(noir_type_alias) => noir_type_alias.accept(self.span, visitor), ItemKind::Struct(noir_struct) => noir_struct.accept(self.span, visitor), ItemKind::ModuleDecl(module_declaration) => { @@ -259,6 +365,10 @@ impl NoirFunction { } pub fn accept_children(&self, visitor: &mut impl Visitor) { + for param in &self.def.parameters { + param.typ.accept(visitor); + } + self.def.body.accept(None, visitor); } } @@ -271,6 +381,9 @@ impl NoirTraitImpl { } pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.trait_name.accept(visitor); + self.object_type.accept(visitor); + for item in &self.items { item.accept(visitor); } @@ -301,6 +414,8 @@ impl TypeImpl { } pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.object_type.accept(visitor); + for (method, span) in &self.methods { method.accept(Some(*span), visitor); } @@ -330,24 +445,40 @@ impl TraitItem { pub fn accept_children(&self, visitor: &mut impl Visitor) { match self { - TraitItem::Function { - name: _, - generics: _, - parameters: _, - return_type: _, - where_clause: _, - body, - } => { - if let Some(body) = body { - body.accept(None, visitor); + TraitItem::Function { name, generics, parameters, return_type, where_clause, body } => { + if visitor.visit_trait_item_function( + name, + generics, + parameters, + return_type, + where_clause, + body, + ) { + for (_name, unresolved_type) in parameters { + unresolved_type.accept(visitor); + } + + return_type.accept(visitor); + + for unresolved_trait_constraint in where_clause { + unresolved_trait_constraint.typ.accept(visitor); + } + + if let Some(body) = body { + body.accept(None, visitor); + } } } - TraitItem::Constant { name: _, typ: _, default_value } => { - if let Some(default_value) = default_value { - default_value.accept(visitor); + TraitItem::Constant { name, typ, default_value } => { + if visitor.visit_trait_item_constant(name, typ, default_value) { + typ.accept(visitor); + + if let Some(default_value) = default_value { + default_value.accept(visitor); + } } } - TraitItem::Type { name: _ } => (), + TraitItem::Type { name } => visitor.visit_trait_item_type(name), } } } @@ -375,13 +506,27 @@ impl UseTree { impl NoirStruct { pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { - visitor.visit_noir_struct(self, span); + if visitor.visit_noir_struct(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + for (_name, unresolved_type) in &self.fields { + unresolved_type.accept(visitor); + } } } impl NoirTypeAlias { pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { - visitor.visit_noir_type_alias(self, span); + if visitor.visit_noir_type_alias(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.typ.accept(visitor); } } @@ -458,7 +603,9 @@ impl Expression { } } ExpressionKind::Variable(path) => { - visitor.visit_variable(path, self.span); + if visitor.visit_variable(path, self.span) { + path.accept(visitor); + } } ExpressionKind::AsTraitPath(as_trait_path) => { as_trait_path.accept(self.span, visitor); @@ -566,6 +713,8 @@ impl ConstructorExpression { } pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.type_name.accept(visitor); + for (_field_name, expression) in &self.fields { expression.accept(visitor); } @@ -633,6 +782,10 @@ impl Lambda { } pub fn accept_children(&self, visitor: &mut impl Visitor) { + for (_, unresolved_type) in &self.parameters { + unresolved_type.accept(visitor); + } + self.body.accept(visitor); } } @@ -703,6 +856,7 @@ impl LetStatement { } pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.r#type.accept(visitor); self.expression.accept(visitor); } } @@ -790,7 +944,127 @@ impl ForRange { impl AsTraitPath { pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { - visitor.visit_as_trait_path(self, span); + if visitor.visit_as_trait_path(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.trait_path.accept(visitor); + self.trait_generics.accept(visitor); + } +} + +impl UnresolvedType { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_unresolved_type(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.typ { + UnresolvedTypeData::Array(unresolved_type_expression, unresolved_type) => { + if visitor.visit_array_type(unresolved_type_expression, unresolved_type, self.span) + { + unresolved_type.accept(visitor); + } + } + UnresolvedTypeData::Slice(unresolved_type) => { + if visitor.visit_slice_type(unresolved_type, self.span) { + unresolved_type.accept(visitor); + } + } + UnresolvedTypeData::Parenthesized(unresolved_type) => { + if visitor.visit_parenthesized_type(unresolved_type, self.span) { + unresolved_type.accept(visitor); + } + } + UnresolvedTypeData::Named(path, generic_type_args, _) => { + if visitor.visit_named_type(path, generic_type_args, self.span) { + path.accept(visitor); + generic_type_args.accept(visitor); + } + } + UnresolvedTypeData::TraitAsType(path, generic_type_args) => { + if visitor.visit_trait_as_type(path, generic_type_args, self.span) { + path.accept(visitor); + generic_type_args.accept(visitor); + } + } + UnresolvedTypeData::MutableReference(unresolved_type) => { + if visitor.visit_mutable_reference_type(unresolved_type, self.span) { + unresolved_type.accept(visitor); + } + } + UnresolvedTypeData::Tuple(unresolved_types) => { + if visitor.visit_tuple_type(unresolved_types, self.span) { + visit_unresolved_types(unresolved_types, visitor); + } + } + UnresolvedTypeData::Function(args, ret, env, unconstrained) => { + if visitor.visit_function_type(args, ret, env, *unconstrained, self.span) { + visit_unresolved_types(args, visitor); + ret.accept(visitor); + env.accept(visitor); + } + } + UnresolvedTypeData::AsTraitPath(as_trait_path) => { + if visitor.visit_as_trait_path_type(as_trait_path, self.span) { + as_trait_path.accept(self.span, visitor); + } + } + UnresolvedTypeData::Expression(_) + | UnresolvedTypeData::FormatString(_, _) + | UnresolvedTypeData::String(_) + | UnresolvedTypeData::Unspecified + | UnresolvedTypeData::Quoted(_) + | UnresolvedTypeData::FieldElement + | UnresolvedTypeData::Integer(_, _) + | UnresolvedTypeData::Bool + | UnresolvedTypeData::Unit + | UnresolvedTypeData::Resolved(_) + | UnresolvedTypeData::Interned(_) + | UnresolvedTypeData::Error => (), + } + } +} + +impl Path { + pub fn accept(&self, visitor: &mut impl Visitor) { + visitor.visit_path(self); + } +} + +impl GenericTypeArgs { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_generic_type_args(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + visit_unresolved_types(&self.ordered_args, visitor); + for (_name, typ) in &self.named_args { + typ.accept(visitor); + } + } +} + +impl FunctionReturnType { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_function_return_type(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + FunctionReturnType::Default(_) => (), + FunctionReturnType::Ty(unresolved_type) => { + unresolved_type.accept(visitor); + } + } } } @@ -799,3 +1073,9 @@ fn visit_expressions(expressions: &[Expression], visitor: &mut impl Visitor) { expression.accept(visitor); } } + +fn visit_unresolved_types(unresolved_type: &[UnresolvedType], visitor: &mut impl Visitor) { + for unresolved_type in unresolved_type { + unresolved_type.accept(visitor); + } +} diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index f339ed19622..23dacefafce 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -16,11 +16,10 @@ use noirc_errors::{Location, Span}; use noirc_frontend::{ ast::{ AsTraitPath, BlockExpression, CallExpression, ConstructorExpression, Expression, - ExpressionKind, ForLoopStatement, Ident, IfExpression, ItemVisibility, LValue, Lambda, - LetStatement, MemberAccessExpression, MethodCallExpression, NoirFunction, NoirStruct, - NoirTraitImpl, Path, PathKind, PathSegment, Pattern, Statement, StatementKind, TraitItem, - TypeImpl, UnresolvedGeneric, UnresolvedGenerics, UnresolvedType, UnresolvedTypeData, - UseTree, UseTreeKind, + ExpressionKind, ForLoopStatement, GenericTypeArgs, Ident, IfExpression, ItemVisibility, + Lambda, LetStatement, MemberAccessExpression, MethodCallExpression, NoirFunction, + NoirStruct, NoirTraitImpl, Path, PathKind, PathSegment, Pattern, Statement, TypeImpl, + UnresolvedGeneric, UnresolvedGenerics, UnresolvedType, UseTree, UseTreeKind, Visitor, }, graph::{CrateId, Dependency}, hir::{ @@ -33,7 +32,7 @@ use noirc_frontend::{ hir_def::traits::Trait, macros_api::{ModuleDefId, NodeInterner}, node_interner::ReferenceId, - parser::{Item, ItemKind}, + parser::{Item, ItemKind, ParsedSubModule}, ParsedModule, StructType, Type, }; use sort_text::underscore_sort_text; @@ -48,7 +47,6 @@ mod completion_items; mod kinds; mod sort_text; mod tests; -mod traversal; pub(crate) fn on_completion_request( state: &mut LspState, @@ -163,7 +161,7 @@ impl<'a> NodeFinder<'a> { } fn find(&mut self, parsed_module: &ParsedModule) -> Option { - self.find_in_parsed_module(parsed_module); + parsed_module.accept(self); if self.completion_items.is_empty() { None @@ -181,1164 +179,1062 @@ impl<'a> NodeFinder<'a> { } } - fn find_in_item(&mut self, item: &Item) { - if let ItemKind::Import(..) = &item.kind { - if let Some(lsp_location) = to_lsp_location(self.files, self.file, item.span) { - self.auto_import_line = (lsp_location.range.end.line + 1) as usize; - } - } - - if !self.includes_span(item.span) { + fn complete_constructor_field_name(&mut self, constructor_expression: &ConstructorExpression) { + let location = + Location::new(constructor_expression.type_name.last_ident().span(), self.file); + let Some(ReferenceId::Struct(struct_id)) = self.interner.find_referenced(location) else { return; - } - - match &item.kind { - ItemKind::Import(use_tree) => { - let mut prefixes = Vec::new(); - self.find_in_use_tree(use_tree, &mut prefixes); - } - ItemKind::Submodules(parsed_sub_module) => { - // Switch `self.module_id` to the submodule - let previous_module_id = self.module_id; - - let def_map = &self.def_maps[&self.module_id.krate]; - let Some(module_data) = def_map.modules().get(self.module_id.local_id.0) else { - return; - }; - if let Some(child_module) = module_data.children.get(&parsed_sub_module.name) { - self.module_id = - ModuleId { krate: self.module_id.krate, local_id: *child_module }; - } - - let old_auto_import_line = self.auto_import_line; - self.nesting += 1; - - if let Some(lsp_location) = to_lsp_location(self.files, self.file, item.span) { - self.auto_import_line = (lsp_location.range.start.line + 1) as usize; - } + }; - self.find_in_parsed_module(&parsed_sub_module.contents); + let struct_type = self.interner.get_struct(struct_id); + let struct_type = struct_type.borrow(); - // Restore the old module before continuing - self.module_id = previous_module_id; - self.nesting -= 1; - self.auto_import_line = old_auto_import_line; - } - ItemKind::Function(noir_function) => self.find_in_noir_function(noir_function), - ItemKind::TraitImpl(noir_trait_impl) => self.find_in_noir_trait_impl(noir_trait_impl), - ItemKind::Impl(type_impl) => self.find_in_type_impl(type_impl), - ItemKind::Global(let_statement) => self.find_in_let_statement(let_statement, false), - ItemKind::TypeAlias(noir_type_alias) => self.find_in_noir_type_alias(noir_type_alias), - ItemKind::Struct(noir_struct) => self.find_in_noir_struct(noir_struct), - ItemKind::Trait(noir_trait) => self.find_in_noir_trait(noir_trait), - ItemKind::ModuleDecl(_) => (), + // First get all of the struct's fields + let mut fields = HashMap::new(); + let fields_as_written = struct_type.get_fields_as_written(); + for (field, typ) in &fields_as_written { + fields.insert(field, typ); } - } - - fn find_in_noir_function(&mut self, noir_function: &NoirFunction) { - let old_type_parameters = self.type_parameters.clone(); - self.collect_type_parameters_in_generics(&noir_function.def.generics); - for param in &noir_function.def.parameters { - self.find_in_unresolved_type(¶m.typ); + // Remove the ones that already exists in the constructor + for (field, _) in &constructor_expression.fields { + fields.remove(&field.0.contents); } - self.find_in_function_return_type(&noir_function.def.return_type); - - self.local_variables.clear(); - for param in &noir_function.def.parameters { - self.collect_local_variables(¶m.pattern); + for (field, typ) in fields { + self.completion_items.push(struct_field_completion_item(field, typ)); } + } - self.find_in_block_expression(&noir_function.def.body); - - self.type_parameters = old_type_parameters; + fn find_in_path(&mut self, path: &Path, requested_items: RequestedItems) { + self.find_in_path_impl(path, requested_items, false); } - fn find_in_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl) { - self.find_in_path(&noir_trait_impl.trait_name, RequestedItems::OnlyTypes); - self.find_in_unresolved_type(&noir_trait_impl.object_type); + fn find_in_path_impl( + &mut self, + path: &Path, + requested_items: RequestedItems, + mut in_the_middle: bool, + ) { + if !self.includes_span(path.span) { + return; + } - self.type_parameters.clear(); - self.collect_type_parameters_in_generics(&noir_trait_impl.impl_generics); + let after_colons = self.byte == Some(b':'); - for item in &noir_trait_impl.items { - self.find_in_trait_impl_item(item); - } + let mut idents: Vec = Vec::new(); - self.type_parameters.clear(); - } + // Find in which ident we are in, and in which part of it + // (it could be that we are completting in the middle of an ident) + for segment in &path.segments { + let ident = &segment.ident; - fn find_in_type_impl(&mut self, type_impl: &TypeImpl) { - self.find_in_unresolved_type(&type_impl.object_type); + // Check if we are at the end of the ident + if self.byte_index == ident.span().end() as usize { + idents.push(ident.clone()); + break; + } - self.type_parameters.clear(); - self.collect_type_parameters_in_generics(&type_impl.generics); + // Check if we are in the middle of an ident + if self.includes_span(ident.span()) { + // If so, take the substring and push that as the list of idents + // we'll do autocompletion for + let offset = self.byte_index - ident.span().start() as usize; + let substring = ident.0.contents[0..offset].to_string(); + let ident = Ident::new( + substring, + Span::from(ident.span().start()..ident.span().start() + offset as u32), + ); + idents.push(ident); + in_the_middle = true; + break; + } - for (method, span) in &type_impl.methods { - self.find_in_noir_function(method); + idents.push(ident.clone()); - // Optimization: stop looking in functions past the completion cursor - if span.end() as usize > self.byte_index { + // Stop if the cursor is right after this ident and '::' + if after_colons && self.byte_index == ident.span().end() as usize + 2 { break; } } - self.type_parameters.clear(); - } + if idents.len() < path.segments.len() { + in_the_middle = true; + } - fn find_in_noir_struct(&mut self, noir_struct: &NoirStruct) { - self.type_parameters.clear(); - self.collect_type_parameters_in_generics(&noir_struct.generics); + let prefix; + let at_root; - for (_name, unresolved_type) in &noir_struct.fields { - self.find_in_unresolved_type(unresolved_type); + if after_colons { + prefix = String::new(); + at_root = false; + } else { + prefix = idents.pop().unwrap().to_string(); + at_root = idents.is_empty(); } - self.type_parameters.clear(); - } + let prefix = prefix.to_case(Case::Snake); - fn find_in_trait_item(&mut self, trait_item: &TraitItem) { - match trait_item { - TraitItem::Function { - name: _, - generics, - parameters, - return_type, - where_clause, - body, - } => { - let old_type_parameters = self.type_parameters.clone(); - self.collect_type_parameters_in_generics(generics); - - for (_name, unresolved_type) in parameters { - self.find_in_unresolved_type(unresolved_type); - } + let is_single_segment = !after_colons && idents.is_empty() && path.kind == PathKind::Plain; + let module_id; - self.find_in_function_return_type(return_type); + let module_completion_kind = if after_colons || !idents.is_empty() { + ModuleCompletionKind::DirectChildren + } else { + ModuleCompletionKind::AllVisibleItems + }; - for unresolved_trait_constraint in where_clause { - self.find_in_unresolved_type(&unresolved_trait_constraint.typ); - } + // When completing in the middle of an ident, we don't want to complete + // with function parameters because there might already be function parameters, + // and in the middle of a path it leads to code that won't compile + let function_completion_kind = if in_the_middle { + FunctionCompletionKind::Name + } else { + FunctionCompletionKind::NameAndParameters + }; - if let Some(body) = body { - self.local_variables.clear(); - for (name, _) in parameters { - self.local_variables.insert(name.to_string(), name.span()); - } - self.find_in_block_expression(body); - }; + if idents.is_empty() { + module_id = self.module_id; + } else { + let Some(module_def_id) = self.resolve_path(idents) else { + return; + }; - self.type_parameters = old_type_parameters; + match module_def_id { + ModuleDefId::ModuleId(id) => module_id = id, + ModuleDefId::TypeId(struct_id) => { + let struct_type = self.interner.get_struct(struct_id); + self.complete_type_methods( + &Type::Struct(struct_type, vec![]), + &prefix, + FunctionKind::Any, + function_completion_kind, + ); + return; + } + ModuleDefId::FunctionId(_) => { + // There's nothing inside a function + return; + } + ModuleDefId::TypeAliasId(type_alias_id) => { + let type_alias = self.interner.get_type_alias(type_alias_id); + let type_alias = type_alias.borrow(); + self.complete_type_methods( + &type_alias.typ, + &prefix, + FunctionKind::Any, + function_completion_kind, + ); + return; + } + ModuleDefId::TraitId(trait_id) => { + let trait_ = self.interner.get_trait(trait_id); + self.complete_trait_methods( + trait_, + &prefix, + FunctionKind::Any, + function_completion_kind, + ); + return; + } + ModuleDefId::GlobalId(_) => return, } - TraitItem::Constant { name: _, typ, default_value } => { - self.find_in_unresolved_type(typ); + } + + self.complete_in_module( + module_id, + &prefix, + path.kind, + at_root, + module_completion_kind, + function_completion_kind, + requested_items, + ); - if let Some(default_value) = default_value { - self.find_in_expression(default_value); + if is_single_segment { + match requested_items { + RequestedItems::AnyItems => { + self.local_variables_completion(&prefix); + self.builtin_functions_completion(&prefix, function_completion_kind); + self.builtin_values_completion(&prefix); + } + RequestedItems::OnlyTypes => { + self.builtin_types_completion(&prefix); + self.type_parameters_completion(&prefix); } } - TraitItem::Type { name: _ } => (), + self.complete_auto_imports(&prefix, requested_items, function_completion_kind); } } - pub(super) fn find_in_call_expression(&mut self, call_expression: &CallExpression) { - // Check if it's this case: - // - // foo::b>|<(...) - // - // In this case we want to suggest items in foo but if they are functions - // we don't want to insert arguments, because they are already there (even if - // they could be wrong) just because inserting them would lead to broken code. - if let ExpressionKind::Variable(path) = &call_expression.func.kind { - if self.includes_span(path.span) { - self.find_in_path_impl(path, RequestedItems::AnyItems, true); - return; + fn local_variables_completion(&mut self, prefix: &str) { + for (name, span) in &self.local_variables { + if name_matches(name, prefix) { + let location = Location::new(*span, self.file); + let description = if let Some(ReferenceId::Local(definition_id)) = + self.interner.reference_at_location(location) + { + let typ = self.interner.definition_type(definition_id); + Some(typ.to_string()) + } else { + None + }; + + self.completion_items.push(simple_completion_item( + name, + CompletionItemKind::VARIABLE, + description, + )); } } + } - // Check if it's this case: - // - // foo.>|<(...) - // - // "foo." is actually broken, but it's parsed as "foo", so this is seen - // as "foo(...)" but if we are at a dot right after "foo" it means it's - // the above case and we want to suggest methods of foo's type. - let after_dot = self.byte == Some(b'.'); - if after_dot && call_expression.func.span.end() as usize == self.byte_index - 1 { - let location = Location::new(call_expression.func.span, self.file); - if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); - let prefix = ""; - self.complete_type_fields_and_methods(&typ, prefix, FunctionCompletionKind::Name); - return; + fn type_parameters_completion(&mut self, prefix: &str) { + for name in &self.type_parameters { + if name_matches(name, prefix) { + self.completion_items.push(simple_completion_item( + name, + CompletionItemKind::TYPE_PARAMETER, + None, + )); } } - - self.find_in_expression(&call_expression.func); - self.find_in_expressions(&call_expression.arguments); } - pub(super) fn find_in_method_call_expression( - &mut self, - method_call_expression: &MethodCallExpression, - ) { - // Check if it's this case: - // - // foo.b>|<(...) - // - // In this case we want to suggest items in foo but if they are functions - // we don't want to insert arguments, because they are already there (even if - // they could be wrong) just because inserting them would lead to broken code. - if self.includes_span(method_call_expression.method_name.span()) { - let location = Location::new(method_call_expression.object.span, self.file); - if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); - let prefix = method_call_expression.method_name.to_string(); - let offset = - self.byte_index - method_call_expression.method_name.span().start() as usize; - let prefix = prefix[0..offset].to_string(); - self.complete_type_fields_and_methods(&typ, &prefix, FunctionCompletionKind::Name); - return; - } - } - - self.find_in_expression(&method_call_expression.object); - self.find_in_expressions(&method_call_expression.arguments); - } - - fn find_in_block_expression(&mut self, block_expression: &BlockExpression) { - let old_local_variables = self.local_variables.clone(); - for statement in &block_expression.statements { - self.find_in_statement(statement); - - // Optimization: stop looking in statements past the completion cursor - if statement.span.end() as usize > self.byte_index { - break; - } - } - self.local_variables = old_local_variables; - } - - fn find_in_statement(&mut self, statement: &Statement) { - match &statement.kind { - StatementKind::Let(let_statement) => { - self.find_in_let_statement(let_statement, true); - } - StatementKind::Constrain(constrain_statement) => { - self.find_in_constrain_statement(constrain_statement); - } - StatementKind::Expression(expression) => { - self.find_in_expression(expression); - } - StatementKind::Assign(assign_statement) => { - self.find_in_assign_statement(assign_statement); - } - StatementKind::For(for_loop_statement) => { - self.find_in_for_loop_statement(for_loop_statement); - } - StatementKind::Comptime(statement) => { - // When entering a comptime block, regular local variables shouldn't be offered anymore - let old_local_variables = self.local_variables.clone(); - self.local_variables.clear(); - - self.find_in_statement(statement); - - self.local_variables = old_local_variables; + fn find_in_use_tree(&mut self, use_tree: &UseTree, prefixes: &mut Vec) { + match &use_tree.kind { + UseTreeKind::Path(ident, alias) => { + prefixes.push(use_tree.prefix.clone()); + self.find_in_use_tree_path(prefixes, ident, alias); + prefixes.pop(); } - StatementKind::Semi(expression) => { - self.find_in_expression(expression); + UseTreeKind::List(use_trees) => { + prefixes.push(use_tree.prefix.clone()); + for use_tree in use_trees { + self.find_in_use_tree(use_tree, prefixes); + } + prefixes.pop(); } - StatementKind::Break - | StatementKind::Continue - | StatementKind::Interned(_) - | StatementKind::Error => (), } } - fn find_in_let_statement( + fn find_in_use_tree_path( &mut self, - let_statement: &LetStatement, - collect_local_variables: bool, + prefixes: &Vec, + ident: &Ident, + alias: &Option, ) { - self.find_in_unresolved_type(&let_statement.r#type); - self.find_in_expression(&let_statement.expression); - - if collect_local_variables { - self.collect_local_variables(&let_statement.pattern); + if let Some(_alias) = alias { + // Won't handle completion if there's an alias (for now) + return; } - } - fn find_in_for_loop_statement(&mut self, for_loop_statement: &ForLoopStatement) { - let old_local_variables = self.local_variables.clone(); - let ident = &for_loop_statement.identifier; - self.local_variables.insert(ident.to_string(), ident.span()); + let after_colons = self.byte == Some(b':'); + let at_ident_end = self.byte_index == ident.span().end() as usize; + let at_ident_colons_end = + after_colons && self.byte_index - 2 == ident.span().end() as usize; - self.find_in_for_range(&for_loop_statement.range); - self.find_in_expression(&for_loop_statement.block); + if !(at_ident_end || at_ident_colons_end) { + return; + } - self.local_variables = old_local_variables; - } + let path_kind = prefixes[0].kind; - fn find_in_lvalue(&mut self, lvalue: &LValue) { - match lvalue { - LValue::Ident(ident) => { - if self.byte == Some(b'.') && ident.span().end() as usize == self.byte_index - 1 { - let location = Location::new(ident.span(), self.file); - if let Some(ReferenceId::Local(definition_id)) = - self.interner.find_referenced(location) - { - let typ = self.interner.definition_type(definition_id); - let prefix = ""; - self.complete_type_fields_and_methods( - &typ, - prefix, - FunctionCompletionKind::NameAndParameters, - ); - } - } - } - LValue::MemberAccess { object, field_name: _, span: _ } => self.find_in_lvalue(object), - LValue::Index { array, index, span: _ } => { - self.find_in_lvalue(array); - self.find_in_expression(index); + let mut segments: Vec = Vec::new(); + for prefix in prefixes { + for segment in &prefix.segments { + segments.push(segment.ident.clone()); } - LValue::Dereference(lvalue, _) => self.find_in_lvalue(lvalue), - LValue::Interned(..) => (), } - } - - fn find_in_expression(&mut self, expression: &Expression) { - match &expression.kind { - ExpressionKind::Literal(literal) => self.find_in_literal(literal), - ExpressionKind::Block(block_expression) => { - self.find_in_block_expression(block_expression); - } - ExpressionKind::Prefix(prefix_expression) => { - self.find_in_expression(&prefix_expression.rhs); - } - ExpressionKind::Index(index_expression) => { - self.find_in_index_expression(index_expression); - } - ExpressionKind::Call(call_expression) => { - self.find_in_call_expression(call_expression); - } - ExpressionKind::MethodCall(method_call_expression) => { - self.find_in_method_call_expression(method_call_expression); - } - ExpressionKind::Constructor(constructor_expression) => { - self.find_in_constructor_expression(constructor_expression); - } - ExpressionKind::MemberAccess(member_access_expression) => { - self.find_in_member_access_expression(member_access_expression); - } - ExpressionKind::Cast(cast_expression) => { - self.find_in_cast_expression(cast_expression); - } - ExpressionKind::Infix(infix_expression) => { - self.find_in_infix_expression(infix_expression); - } - ExpressionKind::If(if_expression) => { - self.find_in_if_expression(if_expression); - } - ExpressionKind::Variable(path) => { - self.find_in_path(path, RequestedItems::AnyItems); - } - ExpressionKind::Tuple(expressions) => { - self.find_in_expressions(expressions); - } - ExpressionKind::Lambda(lambda) => self.find_in_lambda(lambda), - ExpressionKind::Parenthesized(expression) => { - self.find_in_expression(expression); - } - ExpressionKind::Unquote(expression) => { - self.find_in_expression(expression); - } - ExpressionKind::Comptime(block_expression, _) => { - // When entering a comptime block, regular local variables shouldn't be offered anymore - let old_local_variables = self.local_variables.clone(); - self.local_variables.clear(); - self.find_in_block_expression(block_expression); + let module_completion_kind = ModuleCompletionKind::DirectChildren; + let function_completion_kind = FunctionCompletionKind::Name; + let requested_items = RequestedItems::AnyItems; - self.local_variables = old_local_variables; - } - ExpressionKind::Unsafe(block_expression, _) => { - self.find_in_block_expression(block_expression); - } - ExpressionKind::AsTraitPath(as_trait_path) => { - self.find_in_as_trait_path(as_trait_path); - } - ExpressionKind::Quote(_) - | ExpressionKind::Resolved(_) - | ExpressionKind::Interned(_) - | ExpressionKind::Error => (), - } + if after_colons { + // We are right after "::" + segments.push(ident.clone()); - // "foo." (no identifier afterwards) is parsed as the expression on the left hand-side of the dot. - // Here we check if there's a dot at the completion position, and if the expression - // ends right before the dot. If so, it means we want to complete the expression's type fields and methods. - // We only do this after visiting nested expressions, because in an expression like `foo & bar.` we want - // to complete for `bar`, not for `foo & bar`. - if self.completion_items.is_empty() - && self.byte == Some(b'.') - && expression.span.end() as usize == self.byte_index - 1 - { - let location = Location::new(expression.span, self.file); - if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); + if let Some(module_id) = self.resolve_module(segments) { let prefix = ""; - self.complete_type_fields_and_methods( - &typ, + let at_root = false; + self.complete_in_module( + module_id, prefix, - FunctionCompletionKind::NameAndParameters, + path_kind, + at_root, + module_completion_kind, + function_completion_kind, + requested_items, + ); + }; + } else { + // We are right after the last segment + let prefix = ident.to_string().to_case(Case::Snake); + if segments.is_empty() { + let at_root = true; + self.complete_in_module( + self.module_id, + &prefix, + path_kind, + at_root, + module_completion_kind, + function_completion_kind, + requested_items, + ); + } else if let Some(module_id) = self.resolve_module(segments) { + let at_root = false; + self.complete_in_module( + module_id, + &prefix, + path_kind, + at_root, + module_completion_kind, + function_completion_kind, + requested_items, ); } } } - fn find_in_constructor_expression(&mut self, constructor_expression: &ConstructorExpression) { - self.find_in_path(&constructor_expression.type_name, RequestedItems::OnlyTypes); - - // Check if we need to autocomplete the field name - if constructor_expression - .fields - .iter() - .any(|(field_name, _)| field_name.span().end() as usize == self.byte_index) - { - self.complete_constructor_field_name(constructor_expression); - return; + fn collect_local_variables(&mut self, pattern: &Pattern) { + match pattern { + Pattern::Identifier(ident) => { + self.local_variables.insert(ident.to_string(), ident.span()); + } + Pattern::Mutable(pattern, _, _) => self.collect_local_variables(pattern), + Pattern::Tuple(patterns, _) => { + for pattern in patterns { + self.collect_local_variables(pattern); + } + } + Pattern::Struct(_, patterns, _) => { + for (_, pattern) in patterns { + self.collect_local_variables(pattern); + } + } } + } - for (_field_name, expression) in &constructor_expression.fields { - self.find_in_expression(expression); + fn collect_type_parameters_in_generics(&mut self, generics: &UnresolvedGenerics) { + for generic in generics { + self.collect_type_parameters_in_generic(generic); } } - fn complete_constructor_field_name(&mut self, constructor_expression: &ConstructorExpression) { - let location = - Location::new(constructor_expression.type_name.last_ident().span(), self.file); - let Some(ReferenceId::Struct(struct_id)) = self.interner.find_referenced(location) else { - return; + fn collect_type_parameters_in_generic(&mut self, generic: &UnresolvedGeneric) { + match generic { + UnresolvedGeneric::Variable(ident) => { + self.type_parameters.insert(ident.to_string()); + } + UnresolvedGeneric::Numeric { ident, typ: _ } => { + self.type_parameters.insert(ident.to_string()); + } + UnresolvedGeneric::Resolved(..) => (), }; - - let struct_type = self.interner.get_struct(struct_id); - let struct_type = struct_type.borrow(); - - // First get all of the struct's fields - let mut fields = HashMap::new(); - let fields_as_written = struct_type.get_fields_as_written(); - for (field, typ) in &fields_as_written { - fields.insert(field, typ); - } - - // Remove the ones that already exists in the constructor - for (field, _) in &constructor_expression.fields { - fields.remove(&field.0.contents); - } - - for (field, typ) in fields { - self.completion_items.push(struct_field_completion_item(field, typ)); - } } - fn find_in_member_access_expression( + fn complete_type_fields_and_methods( &mut self, - member_access_expression: &MemberAccessExpression, + typ: &Type, + prefix: &str, + function_completion_kind: FunctionCompletionKind, ) { - let ident = &member_access_expression.rhs; - - if self.byte_index == ident.span().end() as usize { - // Assuming member_access_expression is of the form `foo.bar`, we are right after `bar` - let location = Location::new(member_access_expression.lhs.span, self.file); - if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); - let prefix = ident.to_string().to_case(Case::Snake); - self.complete_type_fields_and_methods( - &typ, - &prefix, - FunctionCompletionKind::NameAndParameters, + match typ { + Type::Struct(struct_type, generics) => { + self.complete_struct_fields(&struct_type.borrow(), generics, prefix); + } + Type::MutableReference(typ) => { + return self.complete_type_fields_and_methods( + typ, + prefix, + function_completion_kind, ); - return; } + Type::Alias(type_alias, _) => { + let type_alias = type_alias.borrow(); + return self.complete_type_fields_and_methods( + &type_alias.typ, + prefix, + function_completion_kind, + ); + } + Type::Tuple(types) => { + self.complete_tuple_fields(types); + } + Type::FieldElement + | Type::Array(_, _) + | Type::Slice(_) + | Type::Integer(_, _) + | Type::Bool + | Type::String(_) + | Type::FmtString(_, _) + | Type::Unit + | Type::TypeVariable(_, _) + | Type::TraitAsType(_, _, _) + | Type::NamedGeneric(_, _, _) + | Type::Function(..) + | Type::Forall(_, _) + | Type::Constant(_) + | Type::Quoted(_) + | Type::InfixExpr(_, _, _) + | Type::Error => (), } - - self.find_in_expression(&member_access_expression.lhs); - } - - fn find_in_if_expression(&mut self, if_expression: &IfExpression) { - self.find_in_expression(&if_expression.condition); - - let old_local_variables = self.local_variables.clone(); - self.find_in_expression(&if_expression.consequence); - self.local_variables = old_local_variables; - - if let Some(alternative) = &if_expression.alternative { - let old_local_variables = self.local_variables.clone(); - self.find_in_expression(alternative); - self.local_variables = old_local_variables; - } + + self.complete_type_methods( + typ, + prefix, + FunctionKind::SelfType(typ), + function_completion_kind, + ); } - fn find_in_lambda(&mut self, lambda: &Lambda) { - for (_, unresolved_type) in &lambda.parameters { - self.find_in_unresolved_type(unresolved_type); - } + fn complete_type_methods( + &mut self, + typ: &Type, + prefix: &str, + function_kind: FunctionKind, + function_completion_kind: FunctionCompletionKind, + ) { + let Some(methods_by_name) = self.interner.get_type_methods(typ) else { + return; + }; - let old_local_variables = self.local_variables.clone(); - for (pattern, _) in &lambda.parameters { - self.collect_local_variables(pattern); + for (name, methods) in methods_by_name { + for func_id in methods.iter() { + if name_matches(name, prefix) { + if let Some(completion_item) = self.function_completion_item( + func_id, + function_completion_kind, + function_kind, + ) { + self.completion_items.push(completion_item); + self.suggested_module_def_ids.insert(ModuleDefId::FunctionId(func_id)); + } + } + } } - - self.find_in_expression(&lambda.body); - - self.local_variables = old_local_variables; } - fn find_in_as_trait_path(&mut self, as_trait_path: &AsTraitPath) { - self.find_in_path(&as_trait_path.trait_path, RequestedItems::OnlyTypes); - } - - fn find_in_unresolved_type(&mut self, unresolved_type: &UnresolvedType) { - if !self.includes_span(unresolved_type.span) { - return; + fn complete_trait_methods( + &mut self, + trait_: &Trait, + prefix: &str, + function_kind: FunctionKind, + function_completion_kind: FunctionCompletionKind, + ) { + for (name, func_id) in &trait_.method_ids { + if name_matches(name, prefix) { + if let Some(completion_item) = + self.function_completion_item(*func_id, function_completion_kind, function_kind) + { + self.completion_items.push(completion_item); + self.suggested_module_def_ids.insert(ModuleDefId::FunctionId(*func_id)); + } + } } + } - match &unresolved_type.typ { - UnresolvedTypeData::Array(_, unresolved_type) => { - self.find_in_unresolved_type(unresolved_type); - } - UnresolvedTypeData::Slice(unresolved_type) => { - self.find_in_unresolved_type(unresolved_type); - } - UnresolvedTypeData::Parenthesized(unresolved_type) => { - self.find_in_unresolved_type(unresolved_type); - } - UnresolvedTypeData::Named(path, unresolved_types, _) => { - self.find_in_path(path, RequestedItems::OnlyTypes); - self.find_in_type_args(unresolved_types); - } - UnresolvedTypeData::TraitAsType(path, unresolved_types) => { - self.find_in_path(path, RequestedItems::OnlyTypes); - self.find_in_type_args(unresolved_types); - } - UnresolvedTypeData::MutableReference(unresolved_type) => { - self.find_in_unresolved_type(unresolved_type); - } - UnresolvedTypeData::Tuple(unresolved_types) => { - self.find_in_unresolved_types(unresolved_types); - } - UnresolvedTypeData::Function(args, ret, env, _) => { - self.find_in_unresolved_types(args); - self.find_in_unresolved_type(ret); - self.find_in_unresolved_type(env); - } - UnresolvedTypeData::AsTraitPath(as_trait_path) => { - self.find_in_as_trait_path(as_trait_path); + fn complete_struct_fields( + &mut self, + struct_type: &StructType, + generics: &[Type], + prefix: &str, + ) { + for (name, typ) in &struct_type.get_fields(generics) { + if name_matches(name, prefix) { + self.completion_items.push(struct_field_completion_item(name, typ)); } - UnresolvedTypeData::Expression(_) - | UnresolvedTypeData::FormatString(_, _) - | UnresolvedTypeData::String(_) - | UnresolvedTypeData::Unspecified - | UnresolvedTypeData::Quoted(_) - | UnresolvedTypeData::FieldElement - | UnresolvedTypeData::Integer(_, _) - | UnresolvedTypeData::Bool - | UnresolvedTypeData::Unit - | UnresolvedTypeData::Resolved(_) - | UnresolvedTypeData::Interned(_) - | UnresolvedTypeData::Error => (), } } - fn find_in_path(&mut self, path: &Path, requested_items: RequestedItems) { - self.find_in_path_impl(path, requested_items, false); + fn complete_tuple_fields(&mut self, types: &[Type]) { + for (index, typ) in types.iter().enumerate() { + self.completion_items.push(field_completion_item(&index.to_string(), typ.to_string())); + } } - fn find_in_path_impl( + #[allow(clippy::too_many_arguments)] + fn complete_in_module( &mut self, - path: &Path, + module_id: ModuleId, + prefix: &str, + path_kind: PathKind, + at_root: bool, + module_completion_kind: ModuleCompletionKind, + function_completion_kind: FunctionCompletionKind, requested_items: RequestedItems, - mut in_the_middle: bool, ) { - if !self.includes_span(path.span) { + let def_map = &self.def_maps[&module_id.krate]; + let Some(mut module_data) = def_map.modules().get(module_id.local_id.0) else { return; + }; + + if at_root { + match path_kind { + PathKind::Crate => { + let Some(root_module_data) = def_map.modules().get(def_map.root().0) else { + return; + }; + module_data = root_module_data; + } + PathKind::Super => { + let Some(parent) = module_data.parent else { + return; + }; + let Some(parent_module_data) = def_map.modules().get(parent.0) else { + return; + }; + module_data = parent_module_data; + } + PathKind::Dep => (), + PathKind::Plain => (), + } } - let after_colons = self.byte == Some(b':'); + let function_kind = FunctionKind::Any; - let mut idents: Vec = Vec::new(); + let items = match module_completion_kind { + ModuleCompletionKind::DirectChildren => module_data.definitions(), + ModuleCompletionKind::AllVisibleItems => module_data.scope(), + }; - // Find in which ident we are in, and in which part of it - // (it could be that we are completting in the middle of an ident) - for segment in &path.segments { - let ident = &segment.ident; + for ident in items.names() { + let name = &ident.0.contents; - // Check if we are at the end of the ident - if self.byte_index == ident.span().end() as usize { - idents.push(ident.clone()); - break; - } + if name_matches(name, prefix) { + let per_ns = module_data.find_name(ident); + if let Some((module_def_id, visibility, _)) = per_ns.types { + if is_visible(module_id, self.module_id, visibility, self.def_maps) { + if let Some(completion_item) = self.module_def_id_completion_item( + module_def_id, + name.clone(), + function_completion_kind, + function_kind, + requested_items, + ) { + self.completion_items.push(completion_item); + self.suggested_module_def_ids.insert(module_def_id); + } + } + } - // Check if we are in the middle of an ident - if self.includes_span(ident.span()) { - // If so, take the substring and push that as the list of idents - // we'll do autocompletion for - let offset = self.byte_index - ident.span().start() as usize; - let substring = ident.0.contents[0..offset].to_string(); - let ident = Ident::new( - substring, - Span::from(ident.span().start()..ident.span().start() + offset as u32), - ); - idents.push(ident); - in_the_middle = true; - break; + if let Some((module_def_id, visibility, _)) = per_ns.values { + if is_visible(module_id, self.module_id, visibility, self.def_maps) { + if let Some(completion_item) = self.module_def_id_completion_item( + module_def_id, + name.clone(), + function_completion_kind, + function_kind, + requested_items, + ) { + self.completion_items.push(completion_item); + self.suggested_module_def_ids.insert(module_def_id); + } + } + } } + } - idents.push(ident.clone()); + if at_root && path_kind == PathKind::Plain { + for dependency in self.dependencies { + let dependency_name = dependency.as_name(); + if name_matches(&dependency_name, prefix) { + self.completion_items.push(crate_completion_item(dependency_name)); + } + } - // Stop if the cursor is right after this ident and '::' - if after_colons && self.byte_index == ident.span().end() as usize + 2 { - break; + if name_matches("crate::", prefix) { + self.completion_items.push(simple_completion_item( + "crate::", + CompletionItemKind::KEYWORD, + None, + )); } - } - if idents.len() < path.segments.len() { - in_the_middle = true; + if module_data.parent.is_some() && name_matches("super::", prefix) { + self.completion_items.push(simple_completion_item( + "super::", + CompletionItemKind::KEYWORD, + None, + )); + } } + } - let prefix; - let at_root; - - if after_colons { - prefix = String::new(); - at_root = false; + fn resolve_module(&self, segments: Vec) -> Option { + if let Some(ModuleDefId::ModuleId(module_id)) = self.resolve_path(segments) { + Some(module_id) } else { - prefix = idents.pop().unwrap().to_string(); - at_root = idents.is_empty(); + None } + } - let prefix = prefix.to_case(Case::Snake); + fn resolve_path(&self, segments: Vec) -> Option { + let last_segment = segments.last().unwrap().clone(); - let is_single_segment = !after_colons && idents.is_empty() && path.kind == PathKind::Plain; - let module_id; + let path_segments = segments.into_iter().map(PathSegment::from).collect(); + let path = Path { segments: path_segments, kind: PathKind::Plain, span: Span::default() }; - let module_completion_kind = if after_colons || !idents.is_empty() { - ModuleCompletionKind::DirectChildren - } else { - ModuleCompletionKind::AllVisibleItems - }; + let path_resolver = StandardPathResolver::new(self.root_module_id); + if let Ok(path_resolution) = path_resolver.resolve(self.def_maps, path, &mut None) { + return Some(path_resolution.module_def_id); + } - // When completing in the middle of an ident, we don't want to complete - // with function parameters because there might already be function parameters, - // and in the middle of a path it leads to code that won't compile - let function_completion_kind = if in_the_middle { - FunctionCompletionKind::Name - } else { - FunctionCompletionKind::NameAndParameters - }; + // If we can't resolve a path trough lookup, let's see if the last segment is bound to a type + let location = Location::new(last_segment.span(), self.file); + if let Some(reference_id) = self.interner.find_referenced(location) { + if let Some(id) = module_def_id_from_reference_id(reference_id) { + return Some(id); + } + } - if idents.is_empty() { - module_id = self.module_id; - } else { - let Some(module_def_id) = self.resolve_path(idents) else { - return; - }; + None + } - match module_def_id { - ModuleDefId::ModuleId(id) => module_id = id, - ModuleDefId::TypeId(struct_id) => { - let struct_type = self.interner.get_struct(struct_id); - self.complete_type_methods( - &Type::Struct(struct_type, vec![]), - &prefix, - FunctionKind::Any, - function_completion_kind, - ); - return; - } - ModuleDefId::FunctionId(_) => { - // There's nothing inside a function - return; - } - ModuleDefId::TypeAliasId(type_alias_id) => { - let type_alias = self.interner.get_type_alias(type_alias_id); - let type_alias = type_alias.borrow(); - self.complete_type_methods( - &type_alias.typ, - &prefix, - FunctionKind::Any, - function_completion_kind, - ); - return; - } - ModuleDefId::TraitId(trait_id) => { - let trait_ = self.interner.get_trait(trait_id); - self.complete_trait_methods( - trait_, - &prefix, - FunctionKind::Any, - function_completion_kind, - ); - return; - } - ModuleDefId::GlobalId(_) => return, + fn includes_span(&self, span: Span) -> bool { + span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize + } +} + +impl<'a> Visitor for NodeFinder<'a> { + fn visit_item(&mut self, item: &Item) -> bool { + if let ItemKind::Import(..) = &item.kind { + if let Some(lsp_location) = to_lsp_location(self.files, self.file, item.span) { + self.auto_import_line = (lsp_location.range.end.line + 1) as usize; } } - self.complete_in_module( - module_id, - &prefix, - path.kind, - at_root, - module_completion_kind, - function_completion_kind, - requested_items, - ); + if !self.includes_span(item.span) { + return false; + } - if is_single_segment { - match requested_items { - RequestedItems::AnyItems => { - self.local_variables_completion(&prefix); - self.builtin_functions_completion(&prefix, function_completion_kind); - self.builtin_values_completion(&prefix); - } - RequestedItems::OnlyTypes => { - self.builtin_types_completion(&prefix); - self.type_parameters_completion(&prefix); - } - } - self.complete_auto_imports(&prefix, requested_items, function_completion_kind); + true + } + + fn visit_import(&mut self, use_tree: &UseTree) -> bool { + let mut prefixes = Vec::new(); + self.find_in_use_tree(use_tree, &mut prefixes); + false + } + + fn visit_parsed_submodule(&mut self, parsed_sub_module: &ParsedSubModule, span: Span) -> bool { + // Switch `self.module_id` to the submodule + let previous_module_id = self.module_id; + + let def_map = &self.def_maps[&self.module_id.krate]; + let Some(module_data) = def_map.modules().get(self.module_id.local_id.0) else { + return false; + }; + if let Some(child_module) = module_data.children.get(&parsed_sub_module.name) { + self.module_id = ModuleId { krate: self.module_id.krate, local_id: *child_module }; } + + let old_auto_import_line = self.auto_import_line; + self.nesting += 1; + + if let Some(lsp_location) = to_lsp_location(self.files, self.file, span) { + self.auto_import_line = (lsp_location.range.start.line + 1) as usize; + } + + parsed_sub_module.accept_children(self); + + // Restore the old module before continuing + self.module_id = previous_module_id; + self.nesting -= 1; + self.auto_import_line = old_auto_import_line; + + false } - fn local_variables_completion(&mut self, prefix: &str) { - for (name, span) in &self.local_variables { - if name_matches(name, prefix) { - let location = Location::new(*span, self.file); - let description = if let Some(ReferenceId::Local(definition_id)) = - self.interner.reference_at_location(location) - { - let typ = self.interner.definition_type(definition_id); - Some(typ.to_string()) - } else { - None - }; + fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Option) -> bool { + let old_type_parameters = self.type_parameters.clone(); + self.collect_type_parameters_in_generics(&noir_function.def.generics); - self.completion_items.push(simple_completion_item( - name, - CompletionItemKind::VARIABLE, - description, - )); - } + for param in &noir_function.def.parameters { + param.typ.accept(self); + } + + noir_function.def.return_type.accept(self); + + self.local_variables.clear(); + for param in &noir_function.def.parameters { + self.collect_local_variables(¶m.pattern); } + + noir_function.def.body.accept(span, self); + + self.type_parameters = old_type_parameters; + + false } - fn type_parameters_completion(&mut self, prefix: &str) { - for name in &self.type_parameters { - if name_matches(name, prefix) { - self.completion_items.push(simple_completion_item( - name, - CompletionItemKind::TYPE_PARAMETER, - None, - )); - } + fn visit_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl, _: Span) -> bool { + self.find_in_path(&noir_trait_impl.trait_name, RequestedItems::OnlyTypes); + noir_trait_impl.object_type.accept(self); + + self.type_parameters.clear(); + self.collect_type_parameters_in_generics(&noir_trait_impl.impl_generics); + + for item in &noir_trait_impl.items { + item.accept(self); } + + self.type_parameters.clear(); + + false } - fn find_in_use_tree(&mut self, use_tree: &UseTree, prefixes: &mut Vec) { - match &use_tree.kind { - UseTreeKind::Path(ident, alias) => { - prefixes.push(use_tree.prefix.clone()); - self.find_in_use_tree_path(prefixes, ident, alias); - prefixes.pop(); - } - UseTreeKind::List(use_trees) => { - prefixes.push(use_tree.prefix.clone()); - for use_tree in use_trees { - self.find_in_use_tree(use_tree, prefixes); - } - prefixes.pop(); + fn visit_type_impl(&mut self, type_impl: &TypeImpl, _: Span) -> bool { + type_impl.object_type.accept(self); + + self.type_parameters.clear(); + self.collect_type_parameters_in_generics(&type_impl.generics); + + for (method, span) in &type_impl.methods { + method.accept(Some(*span), self); + + // Optimization: stop looking in functions past the completion cursor + if span.end() as usize > self.byte_index { + break; } } + + self.type_parameters.clear(); + + false } - fn find_in_use_tree_path( - &mut self, - prefixes: &Vec, - ident: &Ident, - alias: &Option, - ) { - if let Some(_alias) = alias { - // Won't handle completion if there's an alias (for now) - return; + fn visit_noir_struct(&mut self, noir_struct: &NoirStruct, _: Span) -> bool { + self.type_parameters.clear(); + self.collect_type_parameters_in_generics(&noir_struct.generics); + + for (_name, unresolved_type) in &noir_struct.fields { + unresolved_type.accept(self); } - let after_colons = self.byte == Some(b':'); - let at_ident_end = self.byte_index == ident.span().end() as usize; - let at_ident_colons_end = - after_colons && self.byte_index - 2 == ident.span().end() as usize; + self.type_parameters.clear(); - if !(at_ident_end || at_ident_colons_end) { - return; + false + } + + fn visit_trait_item_function( + &mut self, + _name: &Ident, + generics: &UnresolvedGenerics, + parameters: &[(Ident, UnresolvedType)], + return_type: &noirc_frontend::ast::FunctionReturnType, + where_clause: &[noirc_frontend::ast::UnresolvedTraitConstraint], + body: &Option, + ) -> bool { + let old_type_parameters = self.type_parameters.clone(); + self.collect_type_parameters_in_generics(generics); + + for (_name, unresolved_type) in parameters { + unresolved_type.accept(self); } - let path_kind = prefixes[0].kind; + return_type.accept(self); - let mut segments: Vec = Vec::new(); - for prefix in prefixes { - for segment in &prefix.segments { - segments.push(segment.ident.clone()); - } + for unresolved_trait_constraint in where_clause { + unresolved_trait_constraint.typ.accept(self); } - let module_completion_kind = ModuleCompletionKind::DirectChildren; - let function_completion_kind = FunctionCompletionKind::Name; - let requested_items = RequestedItems::AnyItems; + if let Some(body) = body { + self.local_variables.clear(); + for (name, _) in parameters { + self.local_variables.insert(name.to_string(), name.span()); + } + body.accept(None, self); + }; - if after_colons { - // We are right after "::" - segments.push(ident.clone()); + self.type_parameters = old_type_parameters; - if let Some(module_id) = self.resolve_module(segments) { + false + } + + fn visit_call_expression(&mut self, call_expression: &CallExpression, _: Span) -> bool { + // + // foo::b>|<(...) + // + // In this case we want to suggest items in foo but if they are functions + // we don't want to insert arguments, because they are already there (even if + // they could be wrong) just because inserting them would lead to broken code. + if let ExpressionKind::Variable(path) = &call_expression.func.kind { + if self.includes_span(path.span) { + self.find_in_path_impl(path, RequestedItems::AnyItems, true); + return false; + } + } + + // Check if it's this case: + // + // foo.>|<(...) + // + // "foo." is actually broken, but it's parsed as "foo", so this is seen + // as "foo(...)" but if we are at a dot right after "foo" it means it's + // the above case and we want to suggest methods of foo's type. + let after_dot = self.byte == Some(b'.'); + if after_dot && call_expression.func.span.end() as usize == self.byte_index - 1 { + let location = Location::new(call_expression.func.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); let prefix = ""; - let at_root = false; - self.complete_in_module( - module_id, - prefix, - path_kind, - at_root, - module_completion_kind, - function_completion_kind, - requested_items, - ); - }; - } else { - // We are right after the last segment - let prefix = ident.to_string().to_case(Case::Snake); - if segments.is_empty() { - let at_root = true; - self.complete_in_module( - self.module_id, - &prefix, - path_kind, - at_root, - module_completion_kind, - function_completion_kind, - requested_items, - ); - } else if let Some(module_id) = self.resolve_module(segments) { - let at_root = false; - self.complete_in_module( - module_id, - &prefix, - path_kind, - at_root, - module_completion_kind, - function_completion_kind, - requested_items, - ); + self.complete_type_fields_and_methods(&typ, prefix, FunctionCompletionKind::Name); + return false; } } + + true } - fn collect_local_variables(&mut self, pattern: &Pattern) { - match pattern { - Pattern::Identifier(ident) => { - self.local_variables.insert(ident.to_string(), ident.span()); - } - Pattern::Mutable(pattern, _, _) => self.collect_local_variables(pattern), - Pattern::Tuple(patterns, _) => { - for pattern in patterns { - self.collect_local_variables(pattern); - } - } - Pattern::Struct(_, patterns, _) => { - for (_, pattern) in patterns { - self.collect_local_variables(pattern); - } + fn visit_method_call_expression( + &mut self, + method_call_expression: &MethodCallExpression, + _: Span, + ) -> bool { + // Check if it's this case: + // + // foo.b>|<(...) + // + // In this case we want to suggest items in foo but if they are functions + // we don't want to insert arguments, because they are already there (even if + // they could be wrong) just because inserting them would lead to broken code. + if self.includes_span(method_call_expression.method_name.span()) { + let location = Location::new(method_call_expression.object.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = method_call_expression.method_name.to_string(); + let offset = + self.byte_index - method_call_expression.method_name.span().start() as usize; + let prefix = prefix[0..offset].to_string(); + self.complete_type_fields_and_methods(&typ, &prefix, FunctionCompletionKind::Name); + return false; } } + + true } - fn collect_type_parameters_in_generics(&mut self, generics: &UnresolvedGenerics) { - for generic in generics { - self.collect_type_parameters_in_generic(generic); + fn visit_block_expression( + &mut self, + block_expression: &BlockExpression, + _: Option, + ) -> bool { + let old_local_variables = self.local_variables.clone(); + for statement in &block_expression.statements { + statement.accept(self); + + // Optimization: stop looking in statements past the completion cursor + if statement.span.end() as usize > self.byte_index { + break; + } } + self.local_variables = old_local_variables; + + false } - fn collect_type_parameters_in_generic(&mut self, generic: &UnresolvedGeneric) { - match generic { - UnresolvedGeneric::Variable(ident) => { - self.type_parameters.insert(ident.to_string()); - } - UnresolvedGeneric::Numeric { ident, typ: _ } => { - self.type_parameters.insert(ident.to_string()); - } - UnresolvedGeneric::Resolved(..) => (), - }; + fn visit_let_statement(&mut self, let_statement: &LetStatement) -> bool { + let_statement.accept_children(self); + self.collect_local_variables(&let_statement.pattern); + false } - fn complete_type_fields_and_methods( - &mut self, - typ: &Type, - prefix: &str, - function_completion_kind: FunctionCompletionKind, - ) { - match typ { - Type::Struct(struct_type, generics) => { - self.complete_struct_fields(&struct_type.borrow(), generics, prefix); - } - Type::MutableReference(typ) => { - return self.complete_type_fields_and_methods( - typ, - prefix, - function_completion_kind, - ); - } - Type::Alias(type_alias, _) => { - let type_alias = type_alias.borrow(); - return self.complete_type_fields_and_methods( - &type_alias.typ, - prefix, - function_completion_kind, - ); - } - Type::Tuple(types) => { - self.complete_tuple_fields(types); - } - Type::FieldElement - | Type::Array(_, _) - | Type::Slice(_) - | Type::Integer(_, _) - | Type::Bool - | Type::String(_) - | Type::FmtString(_, _) - | Type::Unit - | Type::TypeVariable(_, _) - | Type::TraitAsType(_, _, _) - | Type::NamedGeneric(_, _, _) - | Type::Function(..) - | Type::Forall(_, _) - | Type::Constant(_) - | Type::Quoted(_) - | Type::InfixExpr(_, _, _) - | Type::Error => (), + fn visit_global(&mut self, let_statement: &LetStatement, _: Span) -> bool { + let_statement.accept_children(self); + false + } + + fn visit_comptime_statement(&mut self, statement: &Statement) -> bool { + // When entering a comptime block, regular local variables shouldn't be offered anymore + let old_local_variables = self.local_variables.clone(); + self.local_variables.clear(); + + statement.accept(self); + + self.local_variables = old_local_variables; + + false + } + + fn visit_for_loop_statement(&mut self, for_loop_statement: &ForLoopStatement) -> bool { + let old_local_variables = self.local_variables.clone(); + let ident = &for_loop_statement.identifier; + self.local_variables.insert(ident.to_string(), ident.span()); + + for_loop_statement.accept_children(self); + + self.local_variables = old_local_variables; + + false + } + + fn visit_lvalue_ident(&mut self, ident: &Ident) { + if self.byte == Some(b'.') && ident.span().end() as usize == self.byte_index - 1 { + let location = Location::new(ident.span(), self.file); + if let Some(ReferenceId::Local(definition_id)) = self.interner.find_referenced(location) + { + let typ = self.interner.definition_type(definition_id); + let prefix = ""; + self.complete_type_fields_and_methods( + &typ, + prefix, + FunctionCompletionKind::NameAndParameters, + ); + } } + } - self.complete_type_methods( - typ, - prefix, - FunctionKind::SelfType(typ), - function_completion_kind, - ); + fn visit_variable(&mut self, path: &Path, _: Span) -> bool { + self.find_in_path(path, RequestedItems::AnyItems); + false } - fn complete_type_methods( - &mut self, - typ: &Type, - prefix: &str, - function_kind: FunctionKind, - function_completion_kind: FunctionCompletionKind, - ) { - let Some(methods_by_name) = self.interner.get_type_methods(typ) else { - return; - }; + fn visit_expression(&mut self, expression: &Expression) -> bool { + expression.accept_children(self); - for (name, methods) in methods_by_name { - for func_id in methods.iter() { - if name_matches(name, prefix) { - if let Some(completion_item) = self.function_completion_item( - func_id, - function_completion_kind, - function_kind, - ) { - self.completion_items.push(completion_item); - self.suggested_module_def_ids.insert(ModuleDefId::FunctionId(func_id)); - } - } + // "foo." (no identifier afterwards) is parsed as the expression on the left hand-side of the dot. + // Here we check if there's a dot at the completion position, and if the expression + // ends right before the dot. If so, it means we want to complete the expression's type fields and methods. + // We only do this after visiting nested expressions, because in an expression like `foo & bar.` we want + // to complete for `bar`, not for `foo & bar`. + if self.completion_items.is_empty() + && self.byte == Some(b'.') + && expression.span.end() as usize == self.byte_index - 1 + { + let location = Location::new(expression.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = ""; + self.complete_type_fields_and_methods( + &typ, + prefix, + FunctionCompletionKind::NameAndParameters, + ); } } + + false } - fn complete_trait_methods( + fn visit_comptime_expression( &mut self, - trait_: &Trait, - prefix: &str, - function_kind: FunctionKind, - function_completion_kind: FunctionCompletionKind, - ) { - for (name, func_id) in &trait_.method_ids { - if name_matches(name, prefix) { - if let Some(completion_item) = - self.function_completion_item(*func_id, function_completion_kind, function_kind) - { - self.completion_items.push(completion_item); - self.suggested_module_def_ids.insert(ModuleDefId::FunctionId(*func_id)); - } - } - } + block_expression: &BlockExpression, + span: Span, + ) -> bool { + // When entering a comptime block, regular local variables shouldn't be offered anymore + let old_local_variables = self.local_variables.clone(); + self.local_variables.clear(); + + block_expression.accept(Some(span), self); + + self.local_variables = old_local_variables; + + false } - fn complete_struct_fields( + fn visit_constructor_expression( &mut self, - struct_type: &StructType, - generics: &[Type], - prefix: &str, - ) { - for (name, typ) in &struct_type.get_fields(generics) { - if name_matches(name, prefix) { - self.completion_items.push(struct_field_completion_item(name, typ)); - } + constructor_expression: &ConstructorExpression, + _: Span, + ) -> bool { + self.find_in_path(&constructor_expression.type_name, RequestedItems::OnlyTypes); + + // Check if we need to autocomplete the field name + if constructor_expression + .fields + .iter() + .any(|(field_name, _)| field_name.span().end() as usize == self.byte_index) + { + self.complete_constructor_field_name(constructor_expression); + return false; } - } - fn complete_tuple_fields(&mut self, types: &[Type]) { - for (index, typ) in types.iter().enumerate() { - self.completion_items.push(field_completion_item(&index.to_string(), typ.to_string())); + for (_field_name, expression) in &constructor_expression.fields { + expression.accept(self); } + + false } - #[allow(clippy::too_many_arguments)] - fn complete_in_module( + fn visit_member_access_expression( &mut self, - module_id: ModuleId, - prefix: &str, - path_kind: PathKind, - at_root: bool, - module_completion_kind: ModuleCompletionKind, - function_completion_kind: FunctionCompletionKind, - requested_items: RequestedItems, - ) { - let def_map = &self.def_maps[&module_id.krate]; - let Some(mut module_data) = def_map.modules().get(module_id.local_id.0) else { - return; - }; + member_access_expression: &MemberAccessExpression, + _: Span, + ) -> bool { + let ident = &member_access_expression.rhs; - if at_root { - match path_kind { - PathKind::Crate => { - let Some(root_module_data) = def_map.modules().get(def_map.root().0) else { - return; - }; - module_data = root_module_data; - } - PathKind::Super => { - let Some(parent) = module_data.parent else { - return; - }; - let Some(parent_module_data) = def_map.modules().get(parent.0) else { - return; - }; - module_data = parent_module_data; - } - PathKind::Dep => (), - PathKind::Plain => (), + if self.byte_index == ident.span().end() as usize { + // Assuming member_access_expression is of the form `foo.bar`, we are right after `bar` + let location = Location::new(member_access_expression.lhs.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = ident.to_string().to_case(Case::Snake); + self.complete_type_fields_and_methods( + &typ, + &prefix, + FunctionCompletionKind::NameAndParameters, + ); + return false; } } - let function_kind = FunctionKind::Any; - - let items = match module_completion_kind { - ModuleCompletionKind::DirectChildren => module_data.definitions(), - ModuleCompletionKind::AllVisibleItems => module_data.scope(), - }; + true + } - for ident in items.names() { - let name = &ident.0.contents; + fn visit_if_expression(&mut self, if_expression: &IfExpression, _: Span) -> bool { + if_expression.condition.accept(self); - if name_matches(name, prefix) { - let per_ns = module_data.find_name(ident); - if let Some((module_def_id, visibility, _)) = per_ns.types { - if is_visible(module_id, self.module_id, visibility, self.def_maps) { - if let Some(completion_item) = self.module_def_id_completion_item( - module_def_id, - name.clone(), - function_completion_kind, - function_kind, - requested_items, - ) { - self.completion_items.push(completion_item); - self.suggested_module_def_ids.insert(module_def_id); - } - } - } + let old_local_variables = self.local_variables.clone(); + if_expression.consequence.accept(self); + self.local_variables = old_local_variables; - if let Some((module_def_id, visibility, _)) = per_ns.values { - if is_visible(module_id, self.module_id, visibility, self.def_maps) { - if let Some(completion_item) = self.module_def_id_completion_item( - module_def_id, - name.clone(), - function_completion_kind, - function_kind, - requested_items, - ) { - self.completion_items.push(completion_item); - self.suggested_module_def_ids.insert(module_def_id); - } - } - } - } + if let Some(alternative) = &if_expression.alternative { + let old_local_variables = self.local_variables.clone(); + alternative.accept(self); + self.local_variables = old_local_variables; } - if at_root && path_kind == PathKind::Plain { - for dependency in self.dependencies { - let dependency_name = dependency.as_name(); - if name_matches(&dependency_name, prefix) { - self.completion_items.push(crate_completion_item(dependency_name)); - } - } - - if name_matches("crate::", prefix) { - self.completion_items.push(simple_completion_item( - "crate::", - CompletionItemKind::KEYWORD, - None, - )); - } + false + } - if module_data.parent.is_some() && name_matches("super::", prefix) { - self.completion_items.push(simple_completion_item( - "super::", - CompletionItemKind::KEYWORD, - None, - )); - } + fn visit_lambda(&mut self, lambda: &Lambda, _: Span) -> bool { + for (_, unresolved_type) in &lambda.parameters { + unresolved_type.accept(self); } - } - fn resolve_module(&self, segments: Vec) -> Option { - if let Some(ModuleDefId::ModuleId(module_id)) = self.resolve_path(segments) { - Some(module_id) - } else { - None + let old_local_variables = self.local_variables.clone(); + for (pattern, _) in &lambda.parameters { + self.collect_local_variables(pattern); } - } - fn resolve_path(&self, segments: Vec) -> Option { - let last_segment = segments.last().unwrap().clone(); + lambda.body.accept(self); - let path_segments = segments.into_iter().map(PathSegment::from).collect(); - let path = Path { segments: path_segments, kind: PathKind::Plain, span: Span::default() }; + self.local_variables = old_local_variables; - let path_resolver = StandardPathResolver::new(self.root_module_id); - if let Ok(path_resolution) = path_resolver.resolve(self.def_maps, path, &mut None) { - return Some(path_resolution.module_def_id); - } + false + } - // If we can't resolve a path trough lookup, let's see if the last segment is bound to a type - let location = Location::new(last_segment.span(), self.file); - if let Some(reference_id) = self.interner.find_referenced(location) { - if let Some(id) = module_def_id_from_reference_id(reference_id) { - return Some(id); - } - } + fn visit_as_trait_path(&mut self, as_trait_path: &AsTraitPath, _: Span) -> bool { + self.find_in_path(&as_trait_path.trait_path, RequestedItems::OnlyTypes); - None + false } - fn includes_span(&self, span: Span) -> bool { - span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize + fn visit_unresolved_type(&mut self, unresolved_type: &UnresolvedType) -> bool { + self.includes_span(unresolved_type.span) + } + + fn visit_named_type( + &mut self, + path: &Path, + unresolved_types: &GenericTypeArgs, + _: Span, + ) -> bool { + self.find_in_path(path, RequestedItems::OnlyTypes); + unresolved_types.accept(self); + false } } diff --git a/tooling/lsp/src/requests/completion/traversal.rs b/tooling/lsp/src/requests/completion/traversal.rs deleted file mode 100644 index e3bd8ffadf7..00000000000 --- a/tooling/lsp/src/requests/completion/traversal.rs +++ /dev/null @@ -1,126 +0,0 @@ -/// This file includes the completion logic that's just about -/// traversing the AST without any additional logic. -use noirc_frontend::{ - ast::{ - ArrayLiteral, AssignStatement, CastExpression, ConstrainStatement, Expression, ForRange, - FunctionReturnType, GenericTypeArgs, IndexExpression, InfixExpression, Literal, NoirTrait, - NoirTypeAlias, TraitImplItem, UnresolvedType, - }, - ParsedModule, -}; - -use super::NodeFinder; - -impl<'a> NodeFinder<'a> { - pub(super) fn find_in_parsed_module(&mut self, parsed_module: &ParsedModule) { - for item in &parsed_module.items { - self.find_in_item(item); - } - } - - pub(super) fn find_in_trait_impl_item(&mut self, item: &TraitImplItem) { - match item { - TraitImplItem::Function(noir_function) => self.find_in_noir_function(noir_function), - TraitImplItem::Constant(_, _, _) => (), - TraitImplItem::Type { .. } => (), - } - } - - pub(super) fn find_in_noir_trait(&mut self, noir_trait: &NoirTrait) { - for item in &noir_trait.items { - self.find_in_trait_item(item); - } - } - - pub(super) fn find_in_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) { - self.find_in_expression(&constrain_statement.0); - - if let Some(exp) = &constrain_statement.1 { - self.find_in_expression(exp); - } - } - - pub(super) fn find_in_assign_statement(&mut self, assign_statement: &AssignStatement) { - self.find_in_lvalue(&assign_statement.lvalue); - self.find_in_expression(&assign_statement.expression); - } - - pub(super) fn find_in_for_range(&mut self, for_range: &ForRange) { - match for_range { - ForRange::Range(start, end) => { - self.find_in_expression(start); - self.find_in_expression(end); - } - ForRange::Array(expression) => self.find_in_expression(expression), - } - } - - pub(super) fn find_in_expressions(&mut self, expressions: &[Expression]) { - for expression in expressions { - self.find_in_expression(expression); - } - } - - pub(super) fn find_in_literal(&mut self, literal: &Literal) { - match literal { - Literal::Array(array_literal) => self.find_in_array_literal(array_literal), - Literal::Slice(array_literal) => self.find_in_array_literal(array_literal), - Literal::Bool(_) - | Literal::Integer(_, _) - | Literal::Str(_) - | Literal::RawStr(_, _) - | Literal::FmtStr(_) - | Literal::Unit => (), - } - } - - pub(super) fn find_in_array_literal(&mut self, array_literal: &ArrayLiteral) { - match array_literal { - ArrayLiteral::Standard(expressions) => self.find_in_expressions(expressions), - ArrayLiteral::Repeated { repeated_element, length } => { - self.find_in_expression(repeated_element); - self.find_in_expression(length); - } - } - } - - pub(super) fn find_in_index_expression(&mut self, index_expression: &IndexExpression) { - self.find_in_expression(&index_expression.collection); - self.find_in_expression(&index_expression.index); - } - - pub(super) fn find_in_cast_expression(&mut self, cast_expression: &CastExpression) { - self.find_in_expression(&cast_expression.lhs); - } - - pub(super) fn find_in_infix_expression(&mut self, infix_expression: &InfixExpression) { - self.find_in_expression(&infix_expression.lhs); - self.find_in_expression(&infix_expression.rhs); - } - - pub(super) fn find_in_unresolved_types(&mut self, unresolved_type: &[UnresolvedType]) { - for unresolved_type in unresolved_type { - self.find_in_unresolved_type(unresolved_type); - } - } - - pub(super) fn find_in_type_args(&mut self, generics: &GenericTypeArgs) { - self.find_in_unresolved_types(&generics.ordered_args); - for (_name, typ) in &generics.named_args { - self.find_in_unresolved_type(typ); - } - } - - pub(super) fn find_in_function_return_type(&mut self, return_type: &FunctionReturnType) { - match return_type { - noirc_frontend::ast::FunctionReturnType::Default(_) => (), - noirc_frontend::ast::FunctionReturnType::Ty(unresolved_type) => { - self.find_in_unresolved_type(unresolved_type); - } - } - } - - pub(super) fn find_in_noir_type_alias(&mut self, noir_type_alias: &NoirTypeAlias) { - self.find_in_unresolved_type(&noir_type_alias.typ); - } -} From dac5bf08b6c1eef5adbfbb65cca0d6024b230d09 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 30 Aug 2024 10:40:45 -0300 Subject: [PATCH 3/8] Complete all visitor methods --- compiler/noirc_frontend/src/ast/visitor.rs | 249 +++++++++++++++++---- 1 file changed, 206 insertions(+), 43 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs index 2ebed378620..1ea4967bb3e 100644 --- a/compiler/noirc_frontend/src/ast/visitor.rs +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -1,3 +1,4 @@ +use acvm::FieldElement; use noirc_errors::Span; use crate::{ @@ -10,13 +11,18 @@ use crate::{ PrefixExpression, Statement, StatementKind, TraitImplItem, TraitItem, TypeImpl, UseTree, UseTreeKind, }, + node_interner::{ + ExprId, InternedExpressionKind, InternedStatementKind, InternedUnresolvedTypeData, + QuotedTypeId, + }, parser::{Item, ItemKind, ParsedSubModule}, - ParsedModule, + token::Tokens, + ParsedModule, QuotedType, }; use super::{ - FunctionReturnType, GenericTypeArgs, UnresolvedGenerics, UnresolvedTraitConstraint, - UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, + FunctionReturnType, GenericTypeArgs, IntegerBitSize, Signedness, UnresolvedGenerics, + UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, }; /// Implements the [Visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for Noir's AST. @@ -54,6 +60,23 @@ pub trait Visitor { true } + fn visit_trait_impl_item_function(&mut self, _: &NoirFunction) -> bool { + true + } + + fn visit_trait_impl_item_constant( + &mut self, + _name: &Ident, + _typ: &UnresolvedType, + _expression: &Expression, + ) -> bool { + true + } + + fn visit_trait_impl_item_type(&mut self, _name: &Ident, _alias: &UnresolvedType) -> bool { + true + } + fn visit_noir_trait(&mut self, _: &NoirTrait, _: Span) -> bool { true } @@ -113,6 +136,26 @@ pub trait Visitor { true } + fn visit_literal_array(&mut self, _: &ArrayLiteral) -> bool { + true + } + + fn visit_literal_slice(&mut self, _: &ArrayLiteral) -> bool { + true + } + + fn visit_literal_bool(&mut self, _: bool) {} + + fn visit_literal_integer(&mut self, _value: FieldElement, _negative: bool) {} + + fn visit_literal_str(&mut self, _: &str) {} + + fn visit_literal_raw_str(&mut self, _: &str, _: u8) {} + + fn visit_literal_fmt_str(&mut self, _: &str) {} + + fn visit_literal_unit(&mut self) {} + fn visit_block_expression(&mut self, _: &BlockExpression, _: Option) -> bool { true } @@ -177,6 +220,14 @@ pub trait Visitor { true } + fn visit_quote(&mut self, _: &Tokens) {} + + fn visit_resolved_expression(&mut self, _expr_id: ExprId) {} + + fn visit_interned_expression(&mut self, _id: InternedExpressionKind) {} + + fn visit_error_expression(&mut self) {} + fn visit_lambda(&mut self, _: &Lambda, _: Span) -> bool { true } @@ -185,6 +236,18 @@ pub trait Visitor { true } + fn visit_array_literal_standard(&mut self, _: &[Expression]) -> bool { + true + } + + fn visit_array_literal_repeated( + &mut self, + _repeated_element: &Expression, + _length: &Expression, + ) -> bool { + true + } + fn visit_statement(&mut self, _: &Statement) -> bool { true } @@ -217,12 +280,39 @@ pub trait Visitor { true } + fn visit_break(&mut self) {} + + fn visit_continue(&mut self) {} + + fn visit_interned_statement(&mut self, _: InternedStatementKind) {} + + fn visit_error_statement(&mut self) {} + fn visit_lvalue(&mut self, _: &LValue) -> bool { true } fn visit_lvalue_ident(&mut self, _: &Ident) {} + fn visit_lvalue_member_access( + &mut self, + _object: &LValue, + _field_name: &Ident, + _span: Span, + ) -> bool { + true + } + + fn visit_lvalue_index(&mut self, _array: &LValue, _index: &Expression, _span: Span) -> bool { + true + } + + fn visit_lvalue_dereference(&mut self, _lvalue: &LValue, _span: Span) -> bool { + true + } + + fn visit_lvalue_interned(&mut self, _id: InternedExpressionKind, _span: Span) {} + fn visit_for_range(&mut self, _: &ForRange) -> bool { true } @@ -283,6 +373,37 @@ pub trait Visitor { true } + fn visit_expression_type(&mut self, _: &UnresolvedTypeExpression, _: Span) {} + + fn visit_format_string_type( + &mut self, + _: &UnresolvedTypeExpression, + _: &UnresolvedType, + _: Span, + ) -> bool { + true + } + + fn visit_string_type(&mut self, _: &UnresolvedTypeExpression, _: Span) {} + + fn visit_unspecified_type(&mut self, _: Span) {} + + fn visit_quoted_type(&mut self, _: &QuotedType, _: Span) {} + + fn visit_field_element_type(&mut self, _: Span) {} + + fn visit_integer_type(&mut self, _: Signedness, _: IntegerBitSize, _: Span) {} + + fn visit_bool_type(&mut self, _: Span) {} + + fn visit_unit_type(&mut self, _: Span) {} + + fn visit_resolved_type(&mut self, _: QuotedTypeId, _: Span) {} + + fn visit_interned_type(&mut self, _: InternedUnresolvedTypeData, _: Span) {} + + fn visit_error_type(&mut self, _: Span) {} + fn visit_path(&mut self, _: &Path) {} fn visit_generic_type_args(&mut self, _: &GenericTypeArgs) -> bool { @@ -399,9 +520,22 @@ impl TraitImplItem { pub fn accept_children(&self, visitor: &mut impl Visitor) { match self { - TraitImplItem::Function(noir_function) => noir_function.accept(None, visitor), - TraitImplItem::Constant(..) => (), - TraitImplItem::Type { .. } => (), + TraitImplItem::Function(noir_function) => { + if visitor.visit_trait_impl_item_function(noir_function) { + noir_function.accept(None, visitor) + } + } + TraitImplItem::Constant(name, unresolved_type, expression) => { + if visitor.visit_trait_impl_item_constant(name, unresolved_type, expression) { + unresolved_type.accept(visitor); + expression.accept(visitor); + } + } + TraitImplItem::Type { name, alias } => { + if visitor.visit_trait_impl_item_type(name, alias) { + alias.accept(visitor); + } + } } } } @@ -610,10 +744,10 @@ impl Expression { ExpressionKind::AsTraitPath(as_trait_path) => { as_trait_path.accept(self.span, visitor); } - ExpressionKind::Quote(_) - | ExpressionKind::Resolved(_) - | ExpressionKind::Interned(_) - | ExpressionKind::Error => (), + ExpressionKind::Quote(tokens) => visitor.visit_quote(tokens), + ExpressionKind::Resolved(expr_id) => visitor.visit_resolved_expression(*expr_id), + ExpressionKind::Interned(id) => visitor.visit_interned_expression(*id), + ExpressionKind::Error => visitor.visit_error_expression(), } } } @@ -627,15 +761,22 @@ impl Literal { pub fn accept_children(&self, visitor: &mut impl Visitor) { match self { - Literal::Array(array_literal) | Literal::Slice(array_literal) => { - array_literal.accept(visitor); + Literal::Array(array_literal) => { + if visitor.visit_literal_array(array_literal) { + array_literal.accept(visitor); + } + } + Literal::Slice(array_literal) => { + if visitor.visit_literal_slice(array_literal) { + array_literal.accept(visitor); + } } - Literal::Bool(_) - | Literal::Integer(_, _) - | Literal::Str(_) - | Literal::RawStr(_, _) - | Literal::FmtStr(_) - | Literal::Unit => (), + Literal::Bool(value) => visitor.visit_literal_bool(*value), + Literal::Integer(value, negative) => visitor.visit_literal_integer(*value, *negative), + Literal::Str(str) => visitor.visit_literal_str(str), + Literal::RawStr(str, length) => visitor.visit_literal_raw_str(str, *length), + Literal::FmtStr(str) => visitor.visit_literal_fmt_str(str), + Literal::Unit => visitor.visit_literal_unit(), } } } @@ -799,10 +940,16 @@ impl ArrayLiteral { pub fn accept_children(&self, visitor: &mut impl Visitor) { match self { - ArrayLiteral::Standard(expressions) => visit_expressions(expressions, visitor), + ArrayLiteral::Standard(expressions) => { + if visitor.visit_array_literal_standard(expressions) { + visit_expressions(expressions, visitor) + } + } ArrayLiteral::Repeated { repeated_element, length } => { - repeated_element.accept(visitor); - length.accept(visitor); + if visitor.visit_array_literal_repeated(repeated_element, length) { + repeated_element.accept(visitor); + length.accept(visitor); + } } } } @@ -840,10 +987,10 @@ impl Statement { StatementKind::Semi(expression) => { expression.accept(visitor); } - StatementKind::Break - | StatementKind::Continue - | StatementKind::Interned(_) - | StatementKind::Error => (), + StatementKind::Break => visitor.visit_break(), + StatementKind::Continue => visitor.visit_continue(), + StatementKind::Interned(id) => visitor.visit_interned_statement(*id), + StatementKind::Error => visitor.visit_error_statement(), } } } @@ -913,13 +1060,23 @@ impl LValue { pub fn accept_children(&self, visitor: &mut impl Visitor) { match self { LValue::Ident(ident) => visitor.visit_lvalue_ident(ident), - LValue::MemberAccess { object, field_name: _, span: _ } => object.accept(visitor), - LValue::Index { array, index, span: _ } => { - array.accept(visitor); - index.accept(visitor); + LValue::MemberAccess { object, field_name, span } => { + if visitor.visit_lvalue_member_access(object, field_name, *span) { + object.accept(visitor) + } + } + LValue::Index { array, index, span } => { + if visitor.visit_lvalue_index(array, index, *span) { + array.accept(visitor); + index.accept(visitor); + } } - LValue::Dereference(lvalue, _) => lvalue.accept(visitor), - LValue::Interned(..) => (), + LValue::Dereference(lvalue, span) => { + if visitor.visit_lvalue_dereference(lvalue, *span) { + lvalue.accept(visitor) + } + } + LValue::Interned(id, span) => visitor.visit_lvalue_interned(*id, *span), } } } @@ -1014,18 +1171,24 @@ impl UnresolvedType { as_trait_path.accept(self.span, visitor); } } - UnresolvedTypeData::Expression(_) - | UnresolvedTypeData::FormatString(_, _) - | UnresolvedTypeData::String(_) - | UnresolvedTypeData::Unspecified - | UnresolvedTypeData::Quoted(_) - | UnresolvedTypeData::FieldElement - | UnresolvedTypeData::Integer(_, _) - | UnresolvedTypeData::Bool - | UnresolvedTypeData::Unit - | UnresolvedTypeData::Resolved(_) - | UnresolvedTypeData::Interned(_) - | UnresolvedTypeData::Error => (), + UnresolvedTypeData::Expression(expr) => visitor.visit_expression_type(expr, self.span), + UnresolvedTypeData::FormatString(expr, typ) => { + if visitor.visit_format_string_type(expr, typ, self.span) { + typ.accept(visitor); + } + } + UnresolvedTypeData::String(expr) => visitor.visit_string_type(expr, self.span), + UnresolvedTypeData::Unspecified => visitor.visit_unspecified_type(self.span), + UnresolvedTypeData::Quoted(typ) => visitor.visit_quoted_type(typ, self.span), + UnresolvedTypeData::FieldElement => visitor.visit_field_element_type(self.span), + UnresolvedTypeData::Integer(signdness, size) => { + visitor.visit_integer_type(*signdness, *size, self.span) + } + UnresolvedTypeData::Bool => visitor.visit_bool_type(self.span), + UnresolvedTypeData::Unit => visitor.visit_unit_type(self.span), + UnresolvedTypeData::Resolved(id) => visitor.visit_resolved_type(*id, self.span), + UnresolvedTypeData::Interned(id) => visitor.visit_interned_type(*id, self.span), + UnresolvedTypeData::Error => visitor.visit_error_type(self.span), } } } From 7e606d0a3d7d5308427025bfe5f3d2aa1054a7fd Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 30 Aug 2024 10:41:42 -0300 Subject: [PATCH 4/8] clippy --- compiler/noirc_frontend/src/ast/visitor.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs index 1ea4967bb3e..f7270d7b9b7 100644 --- a/compiler/noirc_frontend/src/ast/visitor.rs +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -448,19 +448,19 @@ impl Item { ItemKind::Impl(type_impl) => type_impl.accept(self.span, visitor), ItemKind::Global(let_statement) => { if visitor.visit_global(let_statement, self.span) { - let_statement.accept(visitor) + let_statement.accept(visitor); } } ItemKind::Trait(noir_trait) => noir_trait.accept(self.span, visitor), ItemKind::Import(use_tree) => { if visitor.visit_import(use_tree) { - use_tree.accept(visitor) + use_tree.accept(visitor); } } ItemKind::TypeAlias(noir_type_alias) => noir_type_alias.accept(self.span, visitor), ItemKind::Struct(noir_struct) => noir_struct.accept(self.span, visitor), ItemKind::ModuleDecl(module_declaration) => { - module_declaration.accept(self.span, visitor) + module_declaration.accept(self.span, visitor); } } } @@ -522,7 +522,7 @@ impl TraitImplItem { match self { TraitImplItem::Function(noir_function) => { if visitor.visit_trait_impl_item_function(noir_function) { - noir_function.accept(None, visitor) + noir_function.accept(None, visitor); } } TraitImplItem::Constant(name, unresolved_type, expression) => { @@ -942,7 +942,7 @@ impl ArrayLiteral { match self { ArrayLiteral::Standard(expressions) => { if visitor.visit_array_literal_standard(expressions) { - visit_expressions(expressions, visitor) + visit_expressions(expressions, visitor); } } ArrayLiteral::Repeated { repeated_element, length } => { @@ -1062,7 +1062,7 @@ impl LValue { LValue::Ident(ident) => visitor.visit_lvalue_ident(ident), LValue::MemberAccess { object, field_name, span } => { if visitor.visit_lvalue_member_access(object, field_name, *span) { - object.accept(visitor) + object.accept(visitor); } } LValue::Index { array, index, span } => { @@ -1073,7 +1073,7 @@ impl LValue { } LValue::Dereference(lvalue, span) => { if visitor.visit_lvalue_dereference(lvalue, *span) { - lvalue.accept(visitor) + lvalue.accept(visitor); } } LValue::Interned(id, span) => visitor.visit_lvalue_interned(*id, *span), @@ -1182,7 +1182,7 @@ impl UnresolvedType { UnresolvedTypeData::Quoted(typ) => visitor.visit_quoted_type(typ, self.span), UnresolvedTypeData::FieldElement => visitor.visit_field_element_type(self.span), UnresolvedTypeData::Integer(signdness, size) => { - visitor.visit_integer_type(*signdness, *size, self.span) + visitor.visit_integer_type(*signdness, *size, self.span); } UnresolvedTypeData::Bool => visitor.visit_bool_type(self.span), UnresolvedTypeData::Unit => visitor.visit_unit_type(self.span), From 086488b290ec8411afebf15fc85bd84205cc0b9f Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 30 Aug 2024 10:57:15 -0300 Subject: [PATCH 5/8] Avoid visiting uneeded nodes --- tooling/lsp/src/requests/completion.rs | 6 +----- tooling/lsp/src/requests/signature_help.rs | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index 23dacefafce..92dc9fb48ae 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -805,11 +805,7 @@ impl<'a> Visitor for NodeFinder<'a> { } } - if !self.includes_span(item.span) { - return false; - } - - true + self.includes_span(item.span) } fn visit_import(&mut self, use_tree: &UseTree) -> bool { diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index 23f57dcac4f..8f1bbb1570f 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -9,11 +9,12 @@ use noirc_errors::{Location, Span}; use noirc_frontend::{ ast::{ CallExpression, ConstrainKind, ConstrainStatement, Expression, ExpressionKind, - FunctionReturnType, MethodCallExpression, Visitor, + FunctionReturnType, MethodCallExpression, Statement, Visitor, }, hir_def::{function::FuncMeta, stmt::HirPattern}, macros_api::NodeInterner, node_interner::ReferenceId, + parser::Item, ParsedModule, Type, }; @@ -315,6 +316,18 @@ impl<'a> SignatureFinder<'a> { } impl<'a> Visitor for SignatureFinder<'a> { + fn visit_item(&mut self, item: &Item) -> bool { + self.includes_span(item.span) + } + + fn visit_statement(&mut self, statement: &Statement) -> bool { + self.includes_span(statement.span) + } + + fn visit_expression(&mut self, expression: &Expression) -> bool { + self.includes_span(expression.span) + } + fn visit_call_expression(&mut self, call_expression: &CallExpression, span: Span) -> bool { call_expression.accept_children(self); From 62486eb8b5a2c71d06514bdfa0f39d37f934b0a8 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 30 Aug 2024 16:07:11 -0300 Subject: [PATCH 6/8] Use the Visitor pattern in DocumentSymbol --- tooling/lsp/src/requests/document_symbol.rs | 437 ++++++++++---------- 1 file changed, 208 insertions(+), 229 deletions(-) diff --git a/tooling/lsp/src/requests/document_symbol.rs b/tooling/lsp/src/requests/document_symbol.rs index bda246f7c98..84378810fd4 100644 --- a/tooling/lsp/src/requests/document_symbol.rs +++ b/tooling/lsp/src/requests/document_symbol.rs @@ -10,9 +10,9 @@ use noirc_errors::Span; use noirc_frontend::{ ast::{ Expression, FunctionReturnType, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, - NoirTraitImpl, TraitImplItem, TraitItem, TypeImpl, UnresolvedType, UnresolvedTypeData, + NoirTraitImpl, TypeImpl, UnresolvedType, UnresolvedTypeData, Visitor, }, - parser::{Item, ItemKind, ParsedSubModule}, + parser::ParsedSubModule, ParsedModule, }; @@ -40,8 +40,7 @@ pub(crate) fn on_document_symbol_request( let (parsed_module, _errors) = noirc_frontend::parse_program(source); let mut collector = DocumentSymbolCollector::new(file_id, args.files); - let mut symbols = Vec::new(); - collector.collect_in_parsed_module(&parsed_module, &mut symbols); + let symbols = collector.collect(&parsed_module); DocumentSymbolResponse::Nested(symbols) }) }); @@ -52,67 +51,107 @@ pub(crate) fn on_document_symbol_request( struct DocumentSymbolCollector<'a> { file_id: FileId, files: &'a FileMap, + symbols: Vec, } impl<'a> DocumentSymbolCollector<'a> { fn new(file_id: FileId, files: &'a FileMap) -> Self { - Self { file_id, files } + Self { file_id, files, symbols: Vec::new() } } - fn collect_in_parsed_module( - &mut self, - parsed_module: &ParsedModule, - symbols: &mut Vec, - ) { - for item in &parsed_module.items { - self.collect_in_item(item, symbols); - } + fn collect(&mut self, parsed_module: &ParsedModule) -> Vec { + parsed_module.accept(self); + + std::mem::take(&mut self.symbols) } - fn collect_in_item(&mut self, item: &Item, symbols: &mut Vec) { - match &item.kind { - ItemKind::Function(noir_function) => { - self.collect_in_noir_function(noir_function, item.span, symbols); - } - ItemKind::Struct(noir_struct) => { - self.collect_in_noir_struct(noir_struct, item.span, symbols); - } - ItemKind::Trait(noir_trait) => { - self.collect_in_noir_trait(noir_trait, item.span, symbols); - } - ItemKind::TraitImpl(noir_trait_impl) => { - self.collect_in_noir_trait_impl(noir_trait_impl, item.span, symbols); - } - ItemKind::Impl(type_impl) => { - self.collect_in_type_impl(type_impl, item.span, symbols); - } - ItemKind::Submodules(parsed_sub_module) => { - self.collect_in_parsed_sub_module(parsed_sub_module, item.span, symbols); - } - ItemKind::Global(let_statement) => { - self.collect_in_global(let_statement, item.span, symbols); - } - ItemKind::Import(..) | ItemKind::TypeAlias(..) | ItemKind::ModuleDecl(..) => (), - } + fn collect_in_type(&mut self, name: &Ident, typ: Option<&UnresolvedType>) { + let Some(name_location) = self.to_lsp_location(name.span()) else { + return; + }; + + let span = if let Some(typ) = typ { + Span::from(name.span().start()..typ.span.end()) + } else { + name.span() + }; + + let Some(location) = self.to_lsp_location(span) else { + return; + }; + + #[allow(deprecated)] + self.symbols.push(DocumentSymbol { + name: name.to_string(), + detail: None, + kind: SymbolKind::TYPE_PARAMETER, + tags: None, + deprecated: None, + range: location.range, + selection_range: name_location.range, + children: None, + }); } - fn collect_in_noir_function( + fn collect_in_constant( &mut self, - noir_function: &NoirFunction, - span: Span, - symbols: &mut Vec, + name: &Ident, + typ: &UnresolvedType, + default_value: Option<&Expression>, ) { + let Some(name_location) = self.to_lsp_location(name.span()) else { + return; + }; + + let mut span = name.span(); + + // If there's a type span, extend the span to include it + span = Span::from(span.start()..typ.span.end()); + + // If there's a default value, extend the span to include it + if let Some(default_value) = default_value { + span = Span::from(span.start()..default_value.span.end()); + } + let Some(location) = self.to_lsp_location(span) else { return; }; + #[allow(deprecated)] + self.symbols.push(DocumentSymbol { + name: name.to_string(), + detail: None, + kind: SymbolKind::CONSTANT, + tags: None, + deprecated: None, + range: location.range, + selection_range: name_location.range, + children: None, + }); + } + + fn to_lsp_location(&self, span: Span) -> Option { + super::to_lsp_location(self.files, self.file_id, span) + } +} + +impl<'a> Visitor for DocumentSymbolCollector<'a> { + fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Option) -> bool { + let Some(span) = span else { + return false; + }; + + let Some(location) = self.to_lsp_location(span) else { + return false; + }; + let Some(selection_location) = self.to_lsp_location(noir_function.name_ident().span()) else { - return; + return false; }; #[allow(deprecated)] - symbols.push(DocumentSymbol { + self.symbols.push(DocumentSymbol { name: noir_function.name().to_string(), detail: Some(noir_function.def.signature()), kind: SymbolKind::FUNCTION, @@ -122,20 +161,17 @@ impl<'a> DocumentSymbolCollector<'a> { selection_range: selection_location.range, children: None, }); + + false } - fn collect_in_noir_struct( - &mut self, - noir_struct: &NoirStruct, - span: Span, - symbols: &mut Vec, - ) { + fn visit_noir_struct(&mut self, noir_struct: &NoirStruct, span: Span) -> bool { let Some(location) = self.to_lsp_location(span) else { - return; + return false; }; let Some(selection_location) = self.to_lsp_location(noir_struct.name.span()) else { - return; + return false; }; let mut children = Vec::new(); @@ -164,7 +200,7 @@ impl<'a> DocumentSymbolCollector<'a> { } #[allow(deprecated)] - symbols.push(DocumentSymbol { + self.symbols.push(DocumentSymbol { name: noir_struct.name.to_string(), detail: None, kind: SymbolKind::STRUCT, @@ -174,29 +210,31 @@ impl<'a> DocumentSymbolCollector<'a> { selection_range: selection_location.range, children: Some(children), }); + + false } - fn collect_in_noir_trait( - &mut self, - noir_trait: &NoirTrait, - span: Span, - symbols: &mut Vec, - ) { + fn visit_noir_trait(&mut self, noir_trait: &NoirTrait, span: Span) -> bool { let Some(location) = self.to_lsp_location(span) else { - return; + return false; }; let Some(selection_location) = self.to_lsp_location(noir_trait.name.span()) else { - return; + return false; }; - let mut children = Vec::new(); + let old_symbols = std::mem::take(&mut self.symbols); + self.symbols = Vec::new(); + for item in &noir_trait.items { - self.collect_in_noir_trait_item(item, &mut children); + item.accept(self); } + let children = std::mem::take(&mut self.symbols); + self.symbols = old_symbols; + #[allow(deprecated)] - symbols.push(DocumentSymbol { + self.symbols.push(DocumentSymbol { name: noir_trait.name.to_string(), detail: None, kind: SymbolKind::INTERFACE, @@ -206,153 +244,87 @@ impl<'a> DocumentSymbolCollector<'a> { selection_range: selection_location.range, children: Some(children), }); - } - - fn collect_in_noir_trait_item( - &mut self, - trait_item: &TraitItem, - symbols: &mut Vec, - ) { - // Ideally `TraitItem` has a `span` for the entire definition, and we'd use that - // for the `range` property. For now we do our best to find a reasonable span. - match trait_item { - TraitItem::Function { name, parameters, return_type, body, .. } => { - let Some(name_location) = self.to_lsp_location(name.span()) else { - return; - }; - - let mut span = name.span(); - - // If there are parameters, extend the span to include the last parameter. - if let Some((param_name, _param_type)) = parameters.last() { - span = Span::from(span.start()..param_name.span().end()); - } - - // If there's a return type, extend the span to include it - match return_type { - FunctionReturnType::Default(return_type_span) => { - span = Span::from(span.start()..return_type_span.end()); - } - FunctionReturnType::Ty(typ) => { - span = Span::from(span.start()..typ.span.end()); - } - } - - // If there's a body, extend the span to include it - if let Some(body) = body { - if let Some(statement) = body.statements.last() { - span = Span::from(span.start()..statement.span.end()); - } - } - - let Some(location) = self.to_lsp_location(span) else { - return; - }; - #[allow(deprecated)] - symbols.push(DocumentSymbol { - name: name.to_string(), - detail: None, - kind: SymbolKind::METHOD, - tags: None, - deprecated: None, - range: location.range, - selection_range: name_location.range, - children: None, - }); - } - TraitItem::Constant { name, typ, default_value } => { - self.collect_in_constant(name, typ, default_value.as_ref(), symbols); - } - TraitItem::Type { name } => { - self.collect_in_type(name, None, symbols); - } - } + false } - fn collect_in_constant( + fn visit_trait_item_function( &mut self, name: &Ident, - typ: &UnresolvedType, - default_value: Option<&Expression>, - symbols: &mut Vec, - ) { + _generics: &noirc_frontend::ast::UnresolvedGenerics, + parameters: &[(Ident, UnresolvedType)], + return_type: &FunctionReturnType, + _where_clause: &[noirc_frontend::ast::UnresolvedTraitConstraint], + body: &Option, + ) -> bool { let Some(name_location) = self.to_lsp_location(name.span()) else { - return; + return false; }; let mut span = name.span(); - // If there's a type span, extend the span to include it - span = Span::from(span.start()..typ.span.end()); + // If there are parameters, extend the span to include the last parameter. + if let Some((param_name, _param_type)) = parameters.last() { + span = Span::from(span.start()..param_name.span().end()); + } - // If there's a default value, extend the span to include it - if let Some(default_value) = default_value { - span = Span::from(span.start()..default_value.span.end()); + // If there's a return type, extend the span to include it + match return_type { + FunctionReturnType::Default(return_type_span) => { + span = Span::from(span.start()..return_type_span.end()); + } + FunctionReturnType::Ty(typ) => { + span = Span::from(span.start()..typ.span.end()); + } + } + + // If there's a body, extend the span to include it + if let Some(body) = body { + if let Some(statement) = body.statements.last() { + span = Span::from(span.start()..statement.span.end()); + } } let Some(location) = self.to_lsp_location(span) else { - return; + return false; }; #[allow(deprecated)] - symbols.push(DocumentSymbol { + self.symbols.push(DocumentSymbol { name: name.to_string(), detail: None, - kind: SymbolKind::CONSTANT, + kind: SymbolKind::METHOD, tags: None, deprecated: None, range: location.range, selection_range: name_location.range, children: None, }); + + false } - fn collect_in_type( + fn visit_trait_item_constant( &mut self, name: &Ident, - typ: Option<&UnresolvedType>, - symbols: &mut Vec, - ) { - let Some(name_location) = self.to_lsp_location(name.span()) else { - return; - }; - - let span = if let Some(typ) = typ { - Span::from(name.span().start()..typ.span.end()) - } else { - name.span() - }; - - let Some(location) = self.to_lsp_location(span) else { - return; - }; + typ: &UnresolvedType, + default_value: &Option, + ) -> bool { + self.collect_in_constant(name, typ, default_value.as_ref()); + false + } - #[allow(deprecated)] - symbols.push(DocumentSymbol { - name: name.to_string(), - detail: None, - kind: SymbolKind::TYPE_PARAMETER, - tags: None, - deprecated: None, - range: location.range, - selection_range: name_location.range, - children: None, - }); + fn visit_trait_item_type(&mut self, name: &Ident) { + self.collect_in_type(name, None); } - fn collect_in_noir_trait_impl( - &mut self, - noir_trait_impl: &NoirTraitImpl, - span: Span, - symbols: &mut Vec, - ) { + fn visit_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl, span: Span) -> bool { let Some(location) = self.to_lsp_location(span) else { - return; + return false; }; let Some(name_location) = self.to_lsp_location(noir_trait_impl.trait_name.span) else { - return; + return false; }; let mut trait_name = String::new(); @@ -378,13 +350,18 @@ impl<'a> DocumentSymbolCollector<'a> { trait_name.push('>'); } - let mut children = Vec::new(); + let old_symbols = std::mem::take(&mut self.symbols); + self.symbols = Vec::new(); + for trait_impl_item in &noir_trait_impl.items { - self.collect_in_trait_impl_item(trait_impl_item, &mut children); + trait_impl_item.accept(self); } + let children = std::mem::take(&mut self.symbols); + self.symbols = old_symbols; + #[allow(deprecated)] - symbols.push(DocumentSymbol { + self.symbols.push(DocumentSymbol { name: format!("impl {} for {}", trait_name, noir_trait_impl.object_type), detail: None, kind: SymbolKind::NAMESPACE, @@ -394,54 +371,59 @@ impl<'a> DocumentSymbolCollector<'a> { selection_range: name_location.range, children: Some(children), }); + + false } - fn collect_in_trait_impl_item( - &mut self, - trait_impl_item: &TraitImplItem, - symbols: &mut Vec, - ) { - match trait_impl_item { - TraitImplItem::Function(noir_function) => { - let span = Span::from( - noir_function.name_ident().span().start()..noir_function.span().end(), - ); - self.collect_in_noir_function(noir_function, span, symbols); - } - TraitImplItem::Constant(name, typ, default_value) => { - self.collect_in_constant(name, typ, Some(default_value), symbols); - } - TraitImplItem::Type { name, alias } => self.collect_in_type(name, Some(alias), symbols), - } + fn visit_trait_impl_item_function(&mut self, noir_function: &NoirFunction) -> bool { + let span = + Span::from(noir_function.name_ident().span().start()..noir_function.span().end()); + noir_function.accept(Some(span), self); + false } - fn collect_in_type_impl( + fn visit_trait_impl_item_constant( &mut self, - type_impl: &TypeImpl, - span: Span, - symbols: &mut Vec, - ) { + name: &Ident, + typ: &UnresolvedType, + default_value: &Expression, + ) -> bool { + self.collect_in_constant(name, typ, Some(default_value)); + false + } + + fn visit_trait_impl_item_type(&mut self, name: &Ident, alias: &UnresolvedType) -> bool { + self.collect_in_type(name, Some(alias)); + false + } + + fn visit_type_impl(&mut self, type_impl: &TypeImpl, span: Span) -> bool { let Some(location) = self.to_lsp_location(span) else { - return; + return false; }; let UnresolvedTypeData::Named(name_path, ..) = &type_impl.object_type.typ else { - return; + return false; }; let name = name_path.last_ident(); let Some(name_location) = self.to_lsp_location(name.span()) else { - return; + return false; }; - let mut children = Vec::new(); + let old_symbols = std::mem::take(&mut self.symbols); + self.symbols = Vec::new(); + for (noir_function, noir_function_span) in &type_impl.methods { - self.collect_in_noir_function(noir_function, *noir_function_span, &mut children); + noir_function.accept(Some(*noir_function_span), self); } + let children = std::mem::take(&mut self.symbols); + self.symbols = old_symbols; + #[allow(deprecated)] - symbols.push(DocumentSymbol { + self.symbols.push(DocumentSymbol { name: name.to_string(), detail: None, kind: SymbolKind::NAMESPACE, @@ -451,29 +433,31 @@ impl<'a> DocumentSymbolCollector<'a> { selection_range: name_location.range, children: Some(children), }); + + false } - fn collect_in_parsed_sub_module( - &mut self, - parsed_sub_module: &ParsedSubModule, - span: Span, - symbols: &mut Vec, - ) { + fn visit_parsed_submodule(&mut self, parsed_sub_module: &ParsedSubModule, span: Span) -> bool { let Some(name_location) = self.to_lsp_location(parsed_sub_module.name.span()) else { - return; + return false; }; let Some(location) = self.to_lsp_location(span) else { - return; + return false; }; - let mut children = Vec::new(); + let old_symbols = std::mem::take(&mut self.symbols); + self.symbols = Vec::new(); + for item in &parsed_sub_module.contents.items { - self.collect_in_item(item, &mut children); + item.accept(self); } + let children = std::mem::take(&mut self.symbols); + self.symbols = old_symbols; + #[allow(deprecated)] - symbols.push(DocumentSymbol { + self.symbols.push(DocumentSymbol { name: parsed_sub_module.name.to_string(), detail: None, kind: SymbolKind::MODULE, @@ -483,24 +467,21 @@ impl<'a> DocumentSymbolCollector<'a> { selection_range: name_location.range, children: Some(children), }); + + false } - fn collect_in_global( - &mut self, - global: &LetStatement, - span: Span, - symbols: &mut Vec, - ) { + fn visit_global(&mut self, global: &LetStatement, span: Span) -> bool { let Some(name_location) = self.to_lsp_location(global.pattern.span()) else { - return; + return false; }; let Some(location) = self.to_lsp_location(span) else { - return; + return false; }; #[allow(deprecated)] - symbols.push(DocumentSymbol { + self.symbols.push(DocumentSymbol { name: global.pattern.to_string(), detail: None, kind: SymbolKind::CONSTANT, @@ -510,10 +491,8 @@ impl<'a> DocumentSymbolCollector<'a> { selection_range: name_location.range, children: None, }); - } - fn to_lsp_location(&self, span: Span) -> Option { - super::to_lsp_location(self.files, self.file_id, span) + false } } From dc0e2a290d28696896ee2a6f86429b18364c30e4 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 30 Aug 2024 16:35:19 -0300 Subject: [PATCH 7/8] Use the Visitor pattern for inlay hints --- compiler/noirc_frontend/src/ast/visitor.rs | 55 ++- tooling/lsp/src/requests/inlay_hint.rs | 379 +++++++-------------- 2 files changed, 168 insertions(+), 266 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs index f7270d7b9b7..f34df091c11 100644 --- a/compiler/noirc_frontend/src/ast/visitor.rs +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -21,7 +21,7 @@ use crate::{ }; use super::{ - FunctionReturnType, GenericTypeArgs, IntegerBitSize, Signedness, UnresolvedGenerics, + FunctionReturnType, GenericTypeArgs, IntegerBitSize, Pattern, Signedness, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, }; @@ -413,6 +413,24 @@ pub trait Visitor { fn visit_function_return_type(&mut self, _: &FunctionReturnType) -> bool { true } + + fn visit_pattern(&mut self, _: &Pattern) -> bool { + true + } + + fn visit_identifier_pattern(&mut self, _: &Ident) {} + + fn visit_mutable_pattern(&mut self, _: &Pattern, _: Span, _is_synthesized: bool) -> bool { + true + } + + fn visit_tuple_pattern(&mut self, _: &[Pattern], _: Span) -> bool { + true + } + + fn visit_struct_pattern(&mut self, _: &Path, _: &[(Ident, Pattern)], _: Span) -> bool { + true + } } impl ParsedModule { @@ -1003,6 +1021,7 @@ impl LetStatement { } pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.pattern.accept(visitor); self.r#type.accept(visitor); self.expression.accept(visitor); } @@ -1231,6 +1250,40 @@ impl FunctionReturnType { } } +impl Pattern { + pub fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_pattern(self) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + Pattern::Identifier(ident) => visitor.visit_identifier_pattern(ident), + Pattern::Mutable(pattern, span, is_synthesized) => { + if visitor.visit_mutable_pattern(pattern, *span, *is_synthesized) { + pattern.accept(visitor); + } + } + Pattern::Tuple(patterns, span) => { + if visitor.visit_tuple_pattern(patterns, *span) { + for pattern in patterns { + pattern.accept(visitor); + } + } + } + Pattern::Struct(path, fields, span) => { + if visitor.visit_struct_pattern(path, fields, *span) { + path.accept(visitor); + for (_, pattern) in fields { + pattern.accept(visitor); + } + } + } + } + } +} + fn visit_expressions(expressions: &[Expression], visitor: &mut impl Visitor) { for expression in expressions { expression.accept(visitor); diff --git a/tooling/lsp/src/requests/inlay_hint.rs b/tooling/lsp/src/requests/inlay_hint.rs index 2f6e7dede5d..c580a294a41 100644 --- a/tooling/lsp/src/requests/inlay_hint.rs +++ b/tooling/lsp/src/requests/inlay_hint.rs @@ -10,14 +10,15 @@ use noirc_errors::{Location, Span}; use noirc_frontend::{ self, ast::{ - BlockExpression, Expression, ExpressionKind, Ident, LetStatement, NoirFunction, Pattern, - Statement, StatementKind, TraitImplItem, TraitItem, UnresolvedTypeData, + CallExpression, Expression, ExpressionKind, ForLoopStatement, Ident, Lambda, LetStatement, + MethodCallExpression, NoirFunction, NoirTraitImpl, Pattern, Statement, TypeImpl, + UnresolvedTypeData, Visitor, }, hir_def::stmt::HirPattern, macros_api::NodeInterner, node_interner::ReferenceId, - parser::{Item, ItemKind}, - ParsedModule, Type, TypeBinding, TypeVariable, TypeVariableKind, + parser::{Item, ParsedSubModule}, + Type, TypeBinding, TypeVariable, TypeVariableKind, }; use crate::{utils, LspState}; @@ -47,7 +48,7 @@ pub(crate) fn on_inlay_hint_request( let mut collector = InlayHintCollector::new(args.files, file_id, args.interner, span, options); - collector.collect_in_parsed_module(&parsed_moduled); + parsed_moduled.accept(&mut collector); collector.inlay_hints }) }); @@ -73,266 +74,6 @@ impl<'a> InlayHintCollector<'a> { ) -> InlayHintCollector<'a> { InlayHintCollector { files, file_id, interner, span, options, inlay_hints: Vec::new() } } - fn collect_in_parsed_module(&mut self, parsed_module: &ParsedModule) { - for item in &parsed_module.items { - self.collect_in_item(item); - } - } - - fn collect_in_item(&mut self, item: &Item) { - if !self.intersects_span(item.span) { - return; - } - - match &item.kind { - ItemKind::Function(noir_function) => { - self.collect_in_noir_function(noir_function, item.span); - } - ItemKind::Trait(noir_trait) => { - for item in &noir_trait.items { - self.collect_in_trait_item(item); - } - } - ItemKind::TraitImpl(noir_trait_impl) => { - for trait_impl_item in &noir_trait_impl.items { - self.collect_in_trait_impl_item(trait_impl_item, item.span); - } - - self.show_closing_brace_hint(item.span, || { - format!( - " impl {} for {}", - noir_trait_impl.trait_name, noir_trait_impl.object_type - ) - }); - } - ItemKind::Impl(type_impl) => { - for (noir_function, span) in &type_impl.methods { - self.collect_in_noir_function(noir_function, *span); - } - - self.show_closing_brace_hint(item.span, || { - format!(" impl {}", type_impl.object_type) - }); - } - ItemKind::Global(let_statement) => self.collect_in_let_statement(let_statement), - ItemKind::Submodules(parsed_submodule) => { - self.collect_in_parsed_module(&parsed_submodule.contents); - - self.show_closing_brace_hint(item.span, || { - if parsed_submodule.is_contract { - format!(" contract {}", parsed_submodule.name) - } else { - format!(" mod {}", parsed_submodule.name) - } - }); - } - ItemKind::ModuleDecl(_) => (), - ItemKind::Import(_) => (), - ItemKind::Struct(_) => (), - ItemKind::TypeAlias(_) => (), - } - } - - fn collect_in_trait_item(&mut self, item: &TraitItem) { - match item { - TraitItem::Function { body, .. } => { - if let Some(body) = body { - self.collect_in_block_expression(body); - } - } - TraitItem::Constant { name: _, typ: _, default_value } => { - if let Some(default_value) = default_value { - self.collect_in_expression(default_value); - } - } - TraitItem::Type { .. } => (), - } - } - - fn collect_in_trait_impl_item(&mut self, item: &TraitImplItem, span: Span) { - match item { - TraitImplItem::Function(noir_function) => { - self.collect_in_noir_function(noir_function, span); - } - TraitImplItem::Constant(_name, _typ, default_value) => { - self.collect_in_expression(default_value); - } - TraitImplItem::Type { .. } => (), - } - } - - fn collect_in_noir_function(&mut self, noir_function: &NoirFunction, span: Span) { - self.collect_in_block_expression(&noir_function.def.body); - - self.show_closing_brace_hint(span, || format!(" fn {}", noir_function.def.name)); - } - - fn collect_in_let_statement(&mut self, let_statement: &LetStatement) { - // Only show inlay hints for let variables that don't have an explicit type annotation - if let UnresolvedTypeData::Unspecified = let_statement.r#type.typ { - self.collect_in_pattern(&let_statement.pattern); - }; - - self.collect_in_expression(&let_statement.expression); - } - - fn collect_in_block_expression(&mut self, block_expression: &BlockExpression) { - for statement in &block_expression.statements { - self.collect_in_statement(statement); - } - } - - fn collect_in_statement(&mut self, statement: &Statement) { - if !self.intersects_span(statement.span) { - return; - } - - match &statement.kind { - StatementKind::Let(let_statement) => self.collect_in_let_statement(let_statement), - StatementKind::Constrain(constrain_statement) => { - self.collect_in_expression(&constrain_statement.0); - } - StatementKind::Expression(expression) => self.collect_in_expression(expression), - StatementKind::Assign(assign_statement) => { - self.collect_in_expression(&assign_statement.expression); - } - StatementKind::For(for_loop_statement) => { - self.collect_in_ident(&for_loop_statement.identifier, false); - self.collect_in_expression(&for_loop_statement.block); - } - StatementKind::Comptime(statement) => self.collect_in_statement(statement), - StatementKind::Semi(expression) => self.collect_in_expression(expression), - StatementKind::Break - | StatementKind::Continue - | StatementKind::Interned(_) - | StatementKind::Error => (), - } - } - - fn collect_in_expression(&mut self, expression: &Expression) { - if !self.intersects_span(expression.span) { - return; - } - - match &expression.kind { - ExpressionKind::Block(block_expression) => { - self.collect_in_block_expression(block_expression); - } - ExpressionKind::Prefix(prefix_expression) => { - self.collect_in_expression(&prefix_expression.rhs); - } - ExpressionKind::Index(index_expression) => { - self.collect_in_expression(&index_expression.collection); - self.collect_in_expression(&index_expression.index); - } - ExpressionKind::Call(call_expression) => { - self.collect_call_parameter_names( - get_expression_name(&call_expression.func), - call_expression.func.span, - &call_expression.arguments, - ); - - self.collect_in_expression(&call_expression.func); - for arg in &call_expression.arguments { - self.collect_in_expression(arg); - } - } - ExpressionKind::MethodCall(method_call_expression) => { - self.collect_call_parameter_names( - Some(method_call_expression.method_name.to_string()), - method_call_expression.method_name.span(), - &method_call_expression.arguments, - ); - - self.collect_in_expression(&method_call_expression.object); - for arg in &method_call_expression.arguments { - self.collect_in_expression(arg); - } - } - ExpressionKind::Constructor(constructor_expression) => { - for (_name, expr) in &constructor_expression.fields { - self.collect_in_expression(expr); - } - } - ExpressionKind::MemberAccess(member_access_expression) => { - self.collect_in_expression(&member_access_expression.lhs); - } - ExpressionKind::Cast(cast_expression) => { - self.collect_in_expression(&cast_expression.lhs); - } - ExpressionKind::Infix(infix_expression) => { - self.collect_in_expression(&infix_expression.lhs); - self.collect_in_expression(&infix_expression.rhs); - } - ExpressionKind::If(if_expression) => { - self.collect_in_expression(&if_expression.condition); - self.collect_in_expression(&if_expression.consequence); - if let Some(alternative) = &if_expression.alternative { - self.collect_in_expression(alternative); - } - } - ExpressionKind::Tuple(expressions) => { - for expression in expressions { - self.collect_in_expression(expression); - } - } - ExpressionKind::Lambda(lambda) => { - for (pattern, typ) in &lambda.parameters { - if matches!(typ.typ, UnresolvedTypeData::Unspecified) { - self.collect_in_pattern(pattern); - } - } - - self.collect_in_expression(&lambda.body); - } - ExpressionKind::Parenthesized(parenthesized) => { - self.collect_in_expression(parenthesized); - } - ExpressionKind::Unquote(expression) => { - self.collect_in_expression(expression); - } - ExpressionKind::Comptime(block_expression, _span) => { - self.collect_in_block_expression(block_expression); - } - ExpressionKind::Unsafe(block_expression, _span) => { - self.collect_in_block_expression(block_expression); - } - ExpressionKind::AsTraitPath(path) => { - self.collect_in_ident(&path.impl_item, true); - } - ExpressionKind::Literal(..) - | ExpressionKind::Variable(..) - | ExpressionKind::Quote(..) - | ExpressionKind::Resolved(..) - | ExpressionKind::Interned(..) - | ExpressionKind::Error => (), - } - } - - fn collect_in_pattern(&mut self, pattern: &Pattern) { - if !self.options.type_hints.enabled { - return; - } - - match pattern { - Pattern::Identifier(ident) => { - self.collect_in_ident(ident, true); - } - Pattern::Mutable(pattern, _span, _is_synthesized) => { - self.collect_in_pattern(pattern); - } - Pattern::Tuple(patterns, _span) => { - for pattern in patterns { - self.collect_in_pattern(pattern); - } - } - Pattern::Struct(_path, patterns, _span) => { - for (_ident, pattern) in patterns { - self.collect_in_pattern(pattern); - } - } - } - } fn collect_in_ident(&mut self, ident: &Ident, editable: bool) { if !self.options.type_hints.enabled { @@ -527,6 +268,114 @@ impl<'a> InlayHintCollector<'a> { } } +impl<'a> Visitor for InlayHintCollector<'a> { + fn visit_item(&mut self, item: &Item) -> bool { + self.intersects_span(item.span) + } + + fn visit_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl, span: Span) -> bool { + self.show_closing_brace_hint(span, || { + format!(" impl {} for {}", noir_trait_impl.trait_name, noir_trait_impl.object_type) + }); + + true + } + + fn visit_type_impl(&mut self, type_impl: &TypeImpl, span: Span) -> bool { + self.show_closing_brace_hint(span, || format!(" impl {}", type_impl.object_type)); + + true + } + + fn visit_parsed_submodule(&mut self, parsed_submodule: &ParsedSubModule, span: Span) -> bool { + self.show_closing_brace_hint(span, || { + if parsed_submodule.is_contract { + format!(" contract {}", parsed_submodule.name) + } else { + format!(" mod {}", parsed_submodule.name) + } + }); + + true + } + + fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Option) -> bool { + if let Some(span) = span { + self.show_closing_brace_hint(span, || format!(" fn {}", noir_function.def.name)); + } + + true + } + + fn visit_statement(&mut self, statement: &Statement) -> bool { + self.intersects_span(statement.span) + } + + fn visit_let_statement(&mut self, let_statement: &LetStatement) -> bool { + // Only show inlay hints for let variables that don't have an explicit type annotation + if let UnresolvedTypeData::Unspecified = let_statement.r#type.typ { + let_statement.pattern.accept(self); + }; + + let_statement.expression.accept(self); + + false + } + + fn visit_for_loop_statement(&mut self, for_loop_statement: &ForLoopStatement) -> bool { + self.collect_in_ident(&for_loop_statement.identifier, false); + true + } + + fn visit_expression(&mut self, expression: &Expression) -> bool { + self.intersects_span(expression.span) + } + + fn visit_call_expression(&mut self, call_expression: &CallExpression, _: Span) -> bool { + self.collect_call_parameter_names( + get_expression_name(&call_expression.func), + call_expression.func.span, + &call_expression.arguments, + ); + + true + } + + fn visit_method_call_expression( + &mut self, + method_call_expression: &MethodCallExpression, + _: Span, + ) -> bool { + self.collect_call_parameter_names( + Some(method_call_expression.method_name.to_string()), + method_call_expression.method_name.span(), + &method_call_expression.arguments, + ); + + true + } + + fn visit_lambda(&mut self, lambda: &Lambda, _: Span) -> bool { + for (pattern, typ) in &lambda.parameters { + if matches!(typ.typ, UnresolvedTypeData::Unspecified) { + pattern.accept(self); + } + } + + lambda.body.accept(self); + + false + } + + fn visit_pattern(&mut self, _: &Pattern) -> bool { + self.options.type_hints.enabled + } + + fn visit_identifier_pattern(&mut self, ident: &Ident) { + self.collect_in_ident(ident, true); + } +} + fn string_part(str: impl Into) -> InlayHintLabelPart { InlayHintLabelPart { value: str.into(), location: None, tooltip: None, command: None } } From 89f56ecf68f8e2e8e64aa81a4a14179cf20c2ec6 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Fri, 30 Aug 2024 16:42:12 -0300 Subject: [PATCH 8/8] Always have a span when visiting a NoirFunction --- compiler/noirc_frontend/src/ast/visitor.rs | 18 +++++++++++------- tooling/lsp/src/requests/completion.rs | 6 +++--- tooling/lsp/src/requests/document_symbol.rs | 15 ++------------- tooling/lsp/src/requests/inlay_hint.rs | 6 ++---- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs index f34df091c11..96183d3322f 100644 --- a/compiler/noirc_frontend/src/ast/visitor.rs +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -44,7 +44,7 @@ pub trait Visitor { true } - fn visit_noir_function(&mut self, _: &NoirFunction, _: Option) -> bool { + fn visit_noir_function(&mut self, _: &NoirFunction, _: Span) -> bool { true } @@ -60,7 +60,7 @@ pub trait Visitor { true } - fn visit_trait_impl_item_function(&mut self, _: &NoirFunction) -> bool { + fn visit_trait_impl_item_function(&mut self, _: &NoirFunction, _span: Span) -> bool { true } @@ -459,7 +459,7 @@ impl Item { ItemKind::Submodules(parsed_sub_module) => { parsed_sub_module.accept(self.span, visitor); } - ItemKind::Function(noir_function) => noir_function.accept(Some(self.span), visitor), + ItemKind::Function(noir_function) => noir_function.accept(self.span, visitor), ItemKind::TraitImpl(noir_trait_impl) => { noir_trait_impl.accept(self.span, visitor); } @@ -497,7 +497,7 @@ impl ParsedSubModule { } impl NoirFunction { - pub fn accept(&self, span: Option, visitor: &mut impl Visitor) { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { if visitor.visit_noir_function(self, span) { self.accept_children(visitor); } @@ -539,8 +539,12 @@ impl TraitImplItem { pub fn accept_children(&self, visitor: &mut impl Visitor) { match self { TraitImplItem::Function(noir_function) => { - if visitor.visit_trait_impl_item_function(noir_function) { - noir_function.accept(None, visitor); + let span = Span::from( + noir_function.name_ident().span().start()..noir_function.span().end(), + ); + + if visitor.visit_trait_impl_item_function(noir_function, span) { + noir_function.accept(span, visitor); } } TraitImplItem::Constant(name, unresolved_type, expression) => { @@ -569,7 +573,7 @@ impl TypeImpl { self.object_type.accept(visitor); for (method, span) in &self.methods { - method.accept(Some(*span), visitor); + method.accept(*span, visitor); } } } diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index 92dc9fb48ae..cb27a9684a1 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -843,7 +843,7 @@ impl<'a> Visitor for NodeFinder<'a> { false } - fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Option) -> bool { + fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Span) -> bool { let old_type_parameters = self.type_parameters.clone(); self.collect_type_parameters_in_generics(&noir_function.def.generics); @@ -858,7 +858,7 @@ impl<'a> Visitor for NodeFinder<'a> { self.collect_local_variables(¶m.pattern); } - noir_function.def.body.accept(span, self); + noir_function.def.body.accept(Some(span), self); self.type_parameters = old_type_parameters; @@ -888,7 +888,7 @@ impl<'a> Visitor for NodeFinder<'a> { self.collect_type_parameters_in_generics(&type_impl.generics); for (method, span) in &type_impl.methods { - method.accept(Some(*span), self); + method.accept(*span, self); // Optimization: stop looking in functions past the completion cursor if span.end() as usize > self.byte_index { diff --git a/tooling/lsp/src/requests/document_symbol.rs b/tooling/lsp/src/requests/document_symbol.rs index 84378810fd4..e06a3451681 100644 --- a/tooling/lsp/src/requests/document_symbol.rs +++ b/tooling/lsp/src/requests/document_symbol.rs @@ -136,11 +136,7 @@ impl<'a> DocumentSymbolCollector<'a> { } impl<'a> Visitor for DocumentSymbolCollector<'a> { - fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Option) -> bool { - let Some(span) = span else { - return false; - }; - + fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Span) -> bool { let Some(location) = self.to_lsp_location(span) else { return false; }; @@ -375,13 +371,6 @@ impl<'a> Visitor for DocumentSymbolCollector<'a> { false } - fn visit_trait_impl_item_function(&mut self, noir_function: &NoirFunction) -> bool { - let span = - Span::from(noir_function.name_ident().span().start()..noir_function.span().end()); - noir_function.accept(Some(span), self); - false - } - fn visit_trait_impl_item_constant( &mut self, name: &Ident, @@ -416,7 +405,7 @@ impl<'a> Visitor for DocumentSymbolCollector<'a> { self.symbols = Vec::new(); for (noir_function, noir_function_span) in &type_impl.methods { - noir_function.accept(Some(*noir_function_span), self); + noir_function.accept(*noir_function_span, self); } let children = std::mem::take(&mut self.symbols); diff --git a/tooling/lsp/src/requests/inlay_hint.rs b/tooling/lsp/src/requests/inlay_hint.rs index c580a294a41..43e0227fa26 100644 --- a/tooling/lsp/src/requests/inlay_hint.rs +++ b/tooling/lsp/src/requests/inlay_hint.rs @@ -299,10 +299,8 @@ impl<'a> Visitor for InlayHintCollector<'a> { true } - fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Option) -> bool { - if let Some(span) = span { - self.show_closing_brace_hint(span, || format!(" fn {}", noir_function.def.name)); - } + fn visit_noir_function(&mut self, noir_function: &NoirFunction, span: Span) -> bool { + self.show_closing_brace_hint(span, || format!(" fn {}", noir_function.def.name)); true }