Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: generate __is-aware code for is on unions #819

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions spec/lang/operator/is_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,124 @@ end]]))
end
end
]]))

it("generates type checks expanding unions (#742)", util.gen([[
global record Foo
bar: string
end

global function repro(x:Foo | string | nil): integer
local y = x
if y is string | Foo then
return 1
elseif y is nil then
return 2
end
return 3
end
]], [[
Foo = {}



function repro(x)
local y = x
if type(y) == "string" or type(y) == "table" then
return 1
elseif y == nil then
return 2
end
return 3
end
]]))

it("generates type checks applying __is to discriminated records in unions", util.gen([[
local interface Type
typename: string
end

local record FooType is Type where self.typename == "foo"
end

local record BarType is Type where self.typename == "bar"
end

global function repro(x:Type | string | nil): integer
local y = x
if y is FooType | BarType then
return 1
elseif y is nil then
return 2
end
return 3
end
]], [[










function repro(x)
local y = x
if y.typename == "foo" or y.typename == "bar" then
return 1
elseif y == nil then
return 2
end
return 3
end
]]))

it("generates type checks applying __is to discriminated records in unions expanding alias", util.gen([[
local interface Type
typename: string
end

local record FooType is Type where self.typename == "foo"
end

local record BarType is Type where self.typename == "bar"
end

local type FooBar = FooType | BarType

global function repro(x:Type | string | nil): integer
local y = x
if y is FooBar then
return 1
elseif y is nil then
return 2
end
return 3
end
]], [[












function repro(x)
local y = x
if y.typename == "foo" or y.typename == "bar" then
return 1
elseif y == nil then
return 2
end
return 3
end
]]))
end)

end)
44 changes: 40 additions & 4 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8063,7 +8063,7 @@ do

local immediate, found = find_nominal_type_decl(self, nom)

if type(immediate) == "table" then
if immediate and (immediate.typename == "invalid" or immediate.typename == "typedecl") then
return immediate
end

Expand Down Expand Up @@ -9670,6 +9670,36 @@ a.types[i], b.types[i]), }
end
end

local function make_is_node(self, var, v, t)
local node = node_at(var, { kind = "op", op = { op = "is", arity = 2, prec = 3 } })
node.e1 = var
node.e2 = node_at(var, { kind = "cast", casttype = self:infer_at(var, t) })
self:check_metamethod(node, "__is", self:to_structural(v), self:to_structural(t), v, t)
if node.expanded then
apply_macroexp(node)
end
node.known = IsFact({ var = var.tk, typ = t, w = node })
return node
end

local function convert_is_of_union_to_or_of_is(self, node, v, u)
local var = node.e1
node.op.op = "or"
node.op.arity = 2
node.op.prec = 1
node.e1 = make_is_node(self, var, v, u.types[1])
local at = node
local n = #u.types
for i = 2, n - 1 do
at.e2 = node_at(var, { kind = "op", op = { op = "or", arity = 2, prec = 1 } })
at.e2.e1 = make_is_node(self, var, v, u.types[i])
node.known = OrFact({ f1 = at.e1.known, f2 = at.e2.known, w = node })
at = at.e2
end
at.e2 = make_is_node(self, var, v, u.types[n])
node.known = OrFact({ f1 = at.e1.known, f2 = at.e2.known, w = node })
end

function TypeChecker:match_record_key(tbl, rec, key)
assert(type(tbl) == "table")
assert(type(rec) == "table")
Expand Down Expand Up @@ -12320,9 +12350,15 @@ self:expand_type(node, values, elements) })
if rb.typename == "integer" then
self.all_needs_compat["math"] = true
end
if node.e1.kind == "variable" then
self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub)
node.known = IsFact({ var = node.e1.tk, typ = ub, w = node })
if ra.typename == "typedecl" then
self.errs:add(node, "can only use 'is' on variables, not types")
elseif node.e1.kind == "variable" then
if rb.typename == "union" then
convert_is_of_union_to_or_of_is(self, node, ra, rb)
else
self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub)
node.known = IsFact({ var = node.e1.tk, typ = ub, w = node })
end
else
self.errs:add(node, "can only use 'is' on variables")
end
Expand Down
44 changes: 40 additions & 4 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -8063,7 +8063,7 @@ do

local immediate, found = find_nominal_type_decl(self, nom)
-- if it was previously resolved (or a circular require, or an error), return that;
if immediate is InvalidOrTypeDeclType then
if immediate and immediate is InvalidOrTypeDeclType then
return immediate
end

Expand Down Expand Up @@ -9670,6 +9670,36 @@ do
end
end

local function make_is_node(self: TypeChecker, var: Node, v: Type, t: Type): Node
local node = node_at(var, { kind = "op", op = { op = "is", arity = 2, prec = 3 } })
node.e1 = var
node.e2 = node_at(var, { kind = "cast", casttype = self:infer_at(var, t) })
self:check_metamethod(node, "__is", self:to_structural(v), self:to_structural(t), v, t)
if node.expanded then
apply_macroexp(node)
end
node.known = IsFact { var = var.tk, typ = t, w = node }
return node
end

local function convert_is_of_union_to_or_of_is(self: TypeChecker, node: Node, v: Type, u: UnionType)
local var = node.e1
node.op.op = "or"
node.op.arity = 2
node.op.prec = 1
node.e1 = make_is_node(self, var, v, u.types[1])
local at = node
local n = #u.types
for i = 2, n - 1 do
at.e2 = node_at(var, { kind = "op", op = { op = "or", arity = 2, prec = 1 } })
at.e2.e1 = make_is_node(self, var, v, u.types[i])
node.known = OrFact { f1 = at.e1.known, f2 = at.e2.known, w = node }
at = at.e2
end
at.e2 = make_is_node(self, var, v, u.types[n])
node.known = OrFact { f1 = at.e1.known, f2 = at.e2.known, w = node }
end

function TypeChecker:match_record_key(tbl: Type, rec: Node, key: string): Type, string
assert(type(tbl) == "table")
assert(type(rec) == "table")
Expand Down Expand Up @@ -12320,9 +12350,15 @@ do
if rb.typename == "integer" then
self.all_needs_compat["math"] = true
end
if node.e1.kind == "variable" then
self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub)
node.known = IsFact { var = node.e1.tk, typ = ub, w = node }
if ra is TypeDeclType then
self.errs:add(node, "can only use 'is' on variables, not types")
elseif node.e1.kind == "variable" then
if rb is UnionType then
convert_is_of_union_to_or_of_is(self, node, ra, rb)
else
self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub)
node.known = IsFact { var = node.e1.tk, typ = ub, w = node }
end
else
self.errs:add(node, "can only use 'is' on variables")
end
Expand Down
Loading