Skip to content

Commit

Permalink
Add keyword support to signature match (#2546)
Browse files Browse the repository at this point in the history
Co-authored-by: Andy Waite <[email protected]>
  • Loading branch information
vinistock and andyw8 authored Sep 13, 2024
1 parent c142481 commit 35f3f92
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 23 deletions.
78 changes: 72 additions & 6 deletions lib/ruby_indexer/lib/ruby_indexer/entry.rb
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,26 @@ def format

# Returns `true` if the given call node arguments array matches this method signature. This method will prefer
# returning `true` for situations that cannot be analyzed statically, like the presence of splats, keyword splats
# or forwarding arguments
# or forwarding arguments.
#
# Since this method is used to detect which overload should be displayed in signature help, it will also return
# `true` if there are missing arguments since the user may not be done typing yet. For example:
#
# ```ruby
# def foo(a, b); end
# # All of the following are considered matches because the user might be in the middle of typing and we have to
# # show them the signature
# foo
# foo(1)
# foo(1, 2)
# ```
sig { params(arguments: T::Array[Prism::Node]).returns(T::Boolean) }
def matches?(arguments)
min_pos = 0
max_pos = T.let(0, Numeric)
max_pos = T.let(0, T.any(Integer, Float))
names = []
has_forward = T.let(false, T::Boolean)
has_keyword_rest = T.let(false, T::Boolean)

@parameters.each do |param|
case param
Expand All @@ -617,15 +632,66 @@ def matches?(arguments)
max_pos = Float::INFINITY
when ForwardingParameter
max_pos = Float::INFINITY
has_forward = true
when KeywordParameter, OptionalKeywordParameter
names << param.name
when KeywordRestParameter
has_keyword_rest = true
end
end

_keyword_hash_node, positional_args = arguments.partition { |arg| arg.is_a?(Prism::KeywordHashNode) }
argument_length_is_unknown = positional_args.any? do |arg|
arg.is_a?(Prism::SplatNode) || arg.is_a?(Prism::ForwardingArgumentsNode)
keyword_hash_nodes, positional_args = arguments.partition { |arg| arg.is_a?(Prism::KeywordHashNode) }
keyword_args = T.cast(keyword_hash_nodes.first, T.nilable(Prism::KeywordHashNode))&.elements
forwarding_arguments, positionals = positional_args.partition do |arg|
arg.is_a?(Prism::ForwardingArgumentsNode)
end

argument_length_is_unknown || (min_pos..max_pos).cover?(positional_args.length)
return true if has_forward && min_pos == 0

# If the only argument passed is a forwarding argument, then anything will match
(positionals.empty? && forwarding_arguments.any?) ||
(
# Check if positional arguments match. This includes required, optional, rest arguments. We also need to
# verify if there's a trailing forwading argument, like `def foo(a, ...); end`
positional_arguments_match?(positionals, forwarding_arguments, keyword_args, min_pos, max_pos) &&
# If the positional arguments match, we move on to checking keyword, optional keyword and keyword rest
# arguments. If there's a forward argument, then it will always match. If the method accepts a keyword rest
# (**kwargs), then we can't analyze statically because the user could be passing a hash and we don't know
# what the runtime values inside the hash are.
#
# If none of those match, then we verify if the user is passing the expect names for the keyword arguments
(has_forward || has_keyword_rest || keyword_arguments_match?(keyword_args, names))
)
end

sig do
params(
positional_args: T::Array[Prism::Node],
forwarding_arguments: T::Array[Prism::Node],
keyword_args: T.nilable(T::Array[Prism::Node]),
min_pos: Integer,
max_pos: T.any(Integer, Float),
).returns(T::Boolean)
end
def positional_arguments_match?(positional_args, forwarding_arguments, keyword_args, min_pos, max_pos)
# If the method accepts at least one positional argument and a splat has been passed
(min_pos > 0 && positional_args.any? { |arg| arg.is_a?(Prism::SplatNode) }) ||
# If there's at least one positional argument unaccounted for and a keyword splat has been passed
(min_pos - positional_args.length > 0 && keyword_args&.any? { |arg| arg.is_a?(Prism::AssocSplatNode) }) ||
# If there's at least one positional argument unaccounted for and a forwarding argument has been passed
(min_pos - positional_args.length > 0 && forwarding_arguments.any?) ||
# If the number of positional arguments is within the expected range
(min_pos > 0 && positional_args.length <= max_pos) ||
(min_pos == 0 && positional_args.empty?)
end

sig { params(args: T.nilable(T::Array[Prism::Node]), names: T::Array[Symbol]).returns(T::Boolean) }
def keyword_arguments_match?(args, names)
return true unless args
return true if args.any? { |arg| arg.is_a?(Prism::AssocSplatNode) }

arg_names = args.filter_map { |arg| arg.key.value.to_sym if arg.is_a?(Prism::AssocNode) }
(arg_names - names).empty?
end
end
end
Expand Down
98 changes: 81 additions & 17 deletions lib/ruby_indexer/test/method_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def bar(a, b = 123)
entry = T.must(@index["bar"].first)

# Matching calls
assert_signature_matches(entry, "bar()")
assert_signature_matches(entry, "bar(1)")
assert_signature_matches(entry, "bar(1, 2)")
assert_signature_matches(entry, "bar(...)")
Expand All @@ -510,15 +511,16 @@ def bar(a, b = 123)
assert_signature_matches(entry, "bar(*a, 2)")
assert_signature_matches(entry, "bar(1, **a)")
assert_signature_matches(entry, "bar(1) {}")
# This call is impossible to analyze statically because it depends on whether there are elements inside `a` or
# not. If there's nothing, the call will fail. But if there's anything inside, the hash will become the first
# positional argument
assert_signature_matches(entry, "bar(**a)")

# Non matching calls

refute_signature_matches(entry, "bar()")
refute_signature_matches(entry, "bar(1, 2, 3)")

# TODO: uncomment after keyword support
# refute_signature_matches(entry, "bar(1, b: 2)")
# refute_signature_matches(entry, "bar(1, 2, c: 3)")
refute_signature_matches(entry, "bar(1, b: 2)")
refute_signature_matches(entry, "bar(1, 2, c: 3)")
end

def test_signature_matches_for_a_method_with_argument_forwarding
Expand Down Expand Up @@ -570,8 +572,7 @@ def bar(a, ...)
assert_signature_matches(entry, "bar(1) {}")
assert_signature_matches(entry, "bar(1, 2, 3)")
assert_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}")

refute_signature_matches(entry, "bar()")
assert_signature_matches(entry, "bar()")
end

def test_signature_matches_for_destructured_parameters
Expand All @@ -585,6 +586,8 @@ def bar(a, (b, c))
entry = T.must(@index["bar"].first)

# All calls with at least one positional argument match
assert_signature_matches(entry, "bar()")
assert_signature_matches(entry, "bar(1)")
assert_signature_matches(entry, "bar(1, 2)")
assert_signature_matches(entry, "bar(...)")
assert_signature_matches(entry, "bar(1, ...)")
Expand All @@ -593,15 +596,11 @@ def bar(a, (b, c))
assert_signature_matches(entry, "bar(*a, 2)")
# This matches because `bar(1, *[], 2)` would result in `bar(1, 2)`, which is a valid call
assert_signature_matches(entry, "bar(1, *a, 2)")
assert_signature_matches(entry, "bar(1, **a)")
assert_signature_matches(entry, "bar(1) {}")

refute_signature_matches(entry, "bar()")
refute_signature_matches(entry, "bar(1)")
refute_signature_matches(entry, "bar(1, **a)")
refute_signature_matches(entry, "bar(1, 2, 3)")
refute_signature_matches(entry, "bar(1) {}")

# TODO: uncomment after keyword support
# refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}")
refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}")
end

def test_signature_matches_for_post_parameters
Expand All @@ -626,11 +625,76 @@ def bar(*splat, a)
assert_signature_matches(entry, "bar(1, **a)")
assert_signature_matches(entry, "bar(1, 2, 3)")
assert_signature_matches(entry, "bar(1) {}")
assert_signature_matches(entry, "bar()")

refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}")
end

def test_signature_matches_for_keyword_parameters
index(<<~RUBY)
class Foo
def bar(a:, b: 123)
end
end
RUBY

entry = T.must(@index["bar"].first)

assert_signature_matches(entry, "bar(...)")
assert_signature_matches(entry, "bar()")
assert_signature_matches(entry, "bar(a: 1)")
assert_signature_matches(entry, "bar(a: 1, b: 32)")

refute_signature_matches(entry, "bar(a: 1, c: 2)")
refute_signature_matches(entry, "bar(1, ...)")
refute_signature_matches(entry, "bar(1) {}")
refute_signature_matches(entry, "bar(1, *a)")
refute_signature_matches(entry, "bar(*a, 2)")
refute_signature_matches(entry, "bar(1, *a, 2)")
refute_signature_matches(entry, "bar(1, **a)")
refute_signature_matches(entry, "bar(*a)")
refute_signature_matches(entry, "bar(1)")
refute_signature_matches(entry, "bar(1, 2)")
refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}")
end

def test_signature_matches_for_keyword_splats
index(<<~RUBY)
class Foo
def bar(a, b:, **kwargs)
end
end
RUBY

entry = T.must(@index["bar"].first)

assert_signature_matches(entry, "bar(...)")
assert_signature_matches(entry, "bar()")
assert_signature_matches(entry, "bar(1)")
assert_signature_matches(entry, "bar(1, b: 2)")
assert_signature_matches(entry, "bar(1, b: 2, c: 3, d: 4)")

refute_signature_matches(entry, "bar(1, 2, b: 2)")
end

def test_partial_signature_matches
# It's important to match signatures partially, because we want to figure out which signature we should show while
# the user is in the middle of typing
index(<<~RUBY)
class Foo
def bar(a:, b:)
end
refute_signature_matches(entry, "bar()")
def baz(a, b)
end
end
RUBY

# TODO: uncomment after keyword support
# refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}")
entry = T.must(@index["bar"].first)
assert_signature_matches(entry, "bar(a: 1)")

entry = T.must(@index["baz"].first)
assert_signature_matches(entry, "baz(1)")
end

private
Expand Down

0 comments on commit 35f3f92

Please sign in to comment.