Skip to content

Commit

Permalink
Merge pull request #371 from julia-vscode/sp/better-destruct-typeinf
Browse files Browse the repository at this point in the history
fix(inf): improve type inference with destrcturing assignment
  • Loading branch information
pfitzseb authored Sep 18, 2023
2 parents 6ab1bf4 + 1795678 commit 4f45e40
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
35 changes: 34 additions & 1 deletion src/type_inf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ function infer_type(binding::Binding, scope, state)
end

function infer_type_assignment_rhs(binding, state, scope)
is_destructuring = false
lhs = binding.val.args[1]
rhs = binding.val.args[2]
if is_loop_iter_assignment(binding.val)
settype!(binding, infer_eltype(rhs))
Expand All @@ -43,13 +45,24 @@ function infer_type_assignment_rhs(binding, state, scope)
end
else
if CSTParser.is_func_call(rhs)
if CSTParser.istuple(lhs)
if CSTParser.isparameters(lhs.args[1])
is_destructuring = true
else
return
end
end
callname = CSTParser.get_name(rhs)
if isidentifier(callname)
resolve_ref(callname, scope, state)
if hasref(callname)
rb = get_root_method(refof(callname), state.server)
if (rb isa Binding && (CoreTypes.isdatatype(rb.type) || rb.val isa SymbolServer.DataTypeStore)) || rb isa SymbolServer.DataTypeStore
settype!(binding, rb)
if is_destructuring
infer_destructuring_type(binding, rb)
else
settype!(binding, rb)
end
end
end
end
Expand Down Expand Up @@ -94,6 +107,26 @@ function infer_type_assignment_rhs(binding, state, scope)
end
end

function infer_destructuring_type(binding, rb::SymbolServer.DataTypeStore)
assigned_name = CSTParser.get_name(binding.val)
for (fieldname, fieldtype) in zip(rb.val.fieldnames, rb.val.types)
if fieldname == assigned_name
settype!(binding, fieldtype)
return
end
end
end
function infer_destructuring_type(binding::Binding, rb::EXPR)
assigned_name = string(to_codeobject(binding.name))
scope = scopeof(rb)
names = scope.names
if haskey(names, assigned_name)
b = names[assigned_name]
settype!(binding, b.type)
end
end
infer_destructuring_type(binding, rb::Binding) = infer_destructuring_type(binding, rb.val)

function infer_type_decl(binding, state, scope)
t = binding.val.args[2]
if isidentifier(t)
Expand Down
30 changes: 25 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ f(arg) = arg
# @test parse_and_pass("function f(x::Int) x end")[1][2][3].binding.t == StaticLint.getsymbolserver(server)["Core"].vals["Function"]
let cst = parse_and_pass("""
struct T end
function f(x::T) x end""")
function f(x::T) x end
""")
@test StaticLint.CoreTypes.isdatatype(bindingof(cst.args[1]).type)
@test StaticLint.CoreTypes.isfunction(bindingof(cst.args[2]).type)
@test bindingof(cst.args[2].args[1].args[2]).type == bindingof(cst.args[1])
Expand All @@ -199,7 +200,8 @@ f(arg) = arg
let cst = parse_and_pass("""
struct T end
T() = 1
function f(x::T) x end""")
function f(x::T) x end
""")
@test StaticLint.CoreTypes.isdatatype(bindingof(cst.args[1]).type)
@test StaticLint.CoreTypes.isfunction(bindingof(cst.args[3]).type)
@test bindingof(cst.args[3].args[1].args[2]).type == bindingof(cst.args[1])
Expand All @@ -208,7 +210,8 @@ f(arg) = arg

let cst = parse_and_pass("""
struct T end
t = T()""")
t = T()
""")
@test StaticLint.CoreTypes.isdatatype(bindingof(cst.args[1]).type)
@test bindingof(cst.args[2].args[1]).type == bindingof(cst.args[1])
end
Expand All @@ -222,7 +225,8 @@ f(arg) = arg
import ..B
B.x
end
end""")
end
""")
@test refof(cst.args[1].args[3].args[2].args[3].args[2].args[2].args[1]) == bindingof(cst[1].args[3].args[1].args[3].args[1].args[1])
end

Expand All @@ -235,7 +239,8 @@ f(arg) = arg
end
function f(arg::T1)
arg.field.x
end""");
end
""");
@test refof(cst.args[3].args[2].args[1].args[1].args[1]) == bindingof(cst.args[3].args[1].args[2])
@test refof(cst.args[3].args[2].args[1].args[1].args[2].args[1]) == bindingof(cst.args[2].args[3].args[1])
@test refof(cst.args[3].args[2].args[1].args[2].args[1]) == bindingof(cst.args[1].args[3].args[1])
Expand Down Expand Up @@ -342,6 +347,21 @@ f(arg) = arg
@test refof(cst[3][3][1]) !== nothing
@test refof(cst[3][3][2]) !== nothing
end

let cst = parse_and_pass("""
struct Foo
x::DataType
y::Float64
end
(;x, y) = Foo(1,2)
x
y
""")
mx = cst.args[3].meta
@test mx.ref.type.name.name.name == :DataType
my = cst.args[4].meta
@test my.ref.type.name.name.name == :Float64
end
end

@testset "macros" begin
Expand Down

0 comments on commit 4f45e40

Please sign in to comment.