diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index 4c5393ded..02a32beb7 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -110,6 +110,17 @@ describe("tl types works like check", function() util.assert_popen_close(0, pd:close()) -- TODO check json output end) + + it("does not crash when a require() expression does not resolve (#778)", function() + local name = util.write_tmp_file(finally, [[ + local type Foo = require("bla").baz + ]]) + local pd = io.popen(util.tl_cmd("types", name, "--gen-target=5.1") .. "2>&1 1>" .. util.os_null, "r") + local output = pd:read("*a") + util.assert_popen_close(1, pd:close()) + assert.match("1 syntax error:", output, 1, true) + -- TODO check json output + end) end) describe("on .lua files", function() diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index 17d0d1ba0..e360682c8 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -1152,4 +1152,55 @@ describe("require", function() assert.same({}, result.env.loaded["./types/person.tl"].type_errors) end) end) + + it("in 'local type' accepts dots for extracting nested types", function () + -- ok + util.mock_io(finally, { + ["mod.tl"] = [[ + local record mod + record Foo + something: K + end + end + + return mod + ]], + ["main.tl"] = [[ + local type Foo = require("mod").Foo + local function f(v: Foo) + print(v.something) + end + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({}, result.syntax_errors) + assert.same({}, result.type_errors) + end) + + it("in 'local type' does not accept arbitrary expressions", function () + -- ok + util.mock_io(finally, { + ["mod.tl"] = [[ + local record mod + record Foo + something: K + end + end + + return mod + ]], + ["main.tl"] = [[ + local type Foo = require("mod") + "hello" + local function f(v: Foo) + print(v.something) + end + ]], + }) + local result, err = tl.process("main.tl") + + assert.same({ + { filename = "main.tl", x = 30, y = 1, msg = "require() in type declarations cannot be part of larger expressions" } + }, result.syntax_errors) + end) end) diff --git a/tl.lua b/tl.lua index e550844f5..fbec75771 100644 --- a/tl.lua +++ b/tl.lua @@ -2808,11 +2808,17 @@ do end local function node_is_require_call(n) - if n.e1 and n.e2 and - n.e1.kind == "variable" and n.e1.tk == "require" and + if not (n.e1 and n.e2) then + return nil + end + if n.op and n.op.op == "." then + + return node_is_require_call(n.e1) + elseif n.e1.kind == "variable" and n.e1.tk == "require" and n.e2.kind == "expression_list" and #n.e2 == 1 and n.e2[1].kind == "string" then + return n.e2[1].conststr elseif n.op and n.op.op == "@funcall" and n.e1 and n.e1.tk == "pcall" and @@ -2820,6 +2826,7 @@ do n.e2[1].kind == "variable" and n.e2[1].tk == "require" and n.e2[2].kind == "string" and n.e2[2].conststr then + return n.e2[2].conststr else return nil @@ -4031,11 +4038,20 @@ do if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then local istart = i - i, asgn.value = parse_call_or_assignment(ps, i) - if asgn.value and not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") + i, asgn.value = parse_expression(ps, i) + if asgn.value then + if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then + fail(ps, istart, "require() in type declarations cannot be part of larger expressions") + return i + end + if not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") + return i + end + return i, asgn + else + return i end - return i, asgn end i, asgn.value = parse_newtype(ps, i) @@ -10577,6 +10593,24 @@ self:expand_type(node, values, elements) }) ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty }) return td + elseif value.kind == "op" and + value.op.op == "." then + + local ty = self:get_typedecl(value.e1) + if ty.typename == "typedecl" then + local def = ty.def + if def.typename == "record" then + local t = def.fields[value.e2.tk] + if t then + return a_type(value, "typedecl", { def = t }) + else + return self.errs:invalid_at(value.e2, "type not found") + end + else + return self.errs:invalid_at(value.e2, "type is not a record") + end + end + return ty else local newtype = value.newtype if newtype.typename == "typealias" then diff --git a/tl.tl b/tl.tl index c4da65e8c..fb691011c 100644 --- a/tl.tl +++ b/tl.tl @@ -2808,18 +2808,25 @@ local function parse_literal(ps: ParseState, i: integer): integer, Node end local function node_is_require_call(n: Node): string - if n.e1 and n.e2 -- literal require call - and n.e1.kind == "variable" and n.e1.tk == "require" + if not (n.e1 and n.e2) then + return nil + end + if n.op and n.op.op == "." then + -- require("str").something + return node_is_require_call(n.e1) + elseif n.e1.kind == "variable" and n.e1.tk == "require" and n.e2.kind == "expression_list" and #n.e2 == 1 and n.e2[1].kind == "string" then + -- require("str") return n.e2[1].conststr - elseif n.op and n.op.op == "@funcall" -- pcall(require, "str") + elseif n.op and n.op.op == "@funcall" and n.e1 and n.e1.tk == "pcall" and n.e2 and #n.e2 == 2 and n.e2[1].kind == "variable" and n.e2[1].tk == "require" and n.e2[2].kind == "string" and n.e2[2].conststr then + -- pcall(require, "str") return n.e2[2].conststr else return nil -- table.insert cares about arity @@ -4031,11 +4038,20 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then local istart = i - i, asgn.value = parse_call_or_assignment(ps, i) - if asgn.value and not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") + i, asgn.value = parse_expression(ps, i) + if asgn.value then + if asgn.value.op and asgn.value.op.op ~= "@funcall" and asgn.value.op.op ~= "." then + fail(ps, istart, "require() in type declarations cannot be part of larger expressions") + return i + end + if not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") + return i + end + return i, asgn + else + return i end - return i, asgn end i, asgn.value = parse_newtype(ps, i) @@ -10577,6 +10593,24 @@ do ty = (ty is TypeAliasType) and self:resolve_typealias(ty) or ty local td = (ty is TypeDeclType) and ty or a_type(value, "typedecl", { def = ty } as TypeDeclType) return td + elseif value.kind == "op" + and value.op.op == "." + then + local ty = self:get_typedecl(value.e1) + if ty is TypeDeclType then + local def = ty.def + if def is RecordType then + local t = def.fields[value.e2.tk] + if t then + return a_type(value, "typedecl", { def = t } as TypeDeclType) + else + return self.errs:invalid_at(value.e2, "type not found") + end + else + return self.errs:invalid_at(value.e2, "type is not a record") + end + end + return ty else local newtype = value.newtype if newtype is TypeAliasType then