diff --git a/spec/cli/feat_spec.lua b/spec/cli/feat_spec.lua index 286033387..7741cc8f3 100644 --- a/spec/cli/feat_spec.lua +++ b/spec/cli/feat_spec.lua @@ -43,8 +43,8 @@ local test_cases = { status = 1, match = { "2 errors:", - ":9:22: wrong number of arguments (given 3, expects 2)", - ":19:22: wrong number of arguments (given 3, expects at least 1 and at most 2)", + ":9:22: wrong number of arguments (given 3, expects at most 2)", + ":19:22: wrong number of arguments (given 3, expects at most 2)", } } } diff --git a/spec/pragma/arity_spec.lua b/spec/pragma/arity_spec.lua new file mode 100644 index 000000000..02cf8ccc7 --- /dev/null +++ b/spec/pragma/arity_spec.lua @@ -0,0 +1,237 @@ +local util = require("spec.util") + +describe("pragma arity", function() + describe("on", function() + it("rejects function calls with missing arguments", util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { msg = "wrong number of arguments (given 1, expects 2)" } + })) + + it("accepts optional arguments", util.check([[ + --#pragma arity on + + local function f(x: integer, y?: integer) + print(x + (y or 20)) + end + + print(f(10)) + ]])) + end) + + describe("off", function() + it("accepts function calls with missing arguments", util.check([[ + --#pragma arity off + + local function f(x: integer, y: integer) + print(x + (y or 20)) + end + + print(f(10)) + ]])) + + it("ignores optional argument annotations", util.check([[ + --#pragma arity off + + local function f(x: integer, y?: integer) + print(x + y) + end + + print(f(10)) + ]])) + end) + + describe("no propagation from required module upwards:", function() + it("on then off, with error in 'on'", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity off + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + ]], { + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + + it("on then on, with errors in both", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity on + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + ]], { + { filename = "r.tl", y = 5, msg = "wrong number of arguments (given 1, expects 3)" }, + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + + it("off then on, with error in 'on'", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity off + + local r = require("r") + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { y = 7, filename = "r.tl", msg = "wrong number of arguments (given 1, expects 2)" } + })() + end) + end) + + describe("does propagate downwards into required module:", function() + it("can trigger errors in required modules", function() + util.mock_io(finally, { + ["r.tl"] = [[ + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + + return { + f = f + } + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + + r.f(10) + ]], { + { filename = "r.tl", y = 4, msg = "wrong number of arguments (given 1, expects 3)" }, + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + { filename = "foo.tl", y = 17, msg = "wrong number of arguments (given 1, expects 3)" }, + })() + end) + + it("can be used to load modules with different settings", function() + util.mock_io(finally, { + ["r.tl"] = [[ + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + + return { + f = f + } + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + --#pragma arity off + local r = require("r") + --#pragma arity on + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + + r.f(10) -- no error here! + ]], { + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 17, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + end) + + describe("invalid", function() + it("rejects invalid value", util.check_type_error([[ + --#pragma arity invalid_value + + local function f(x: integer, y?: integer) + print(x + y) + end + + print(f(10)) + ]], { + { y = 1, msg = "invalid value for pragma 'arity': invalid_value" } + })) + end) +end) diff --git a/spec/pragma/invalid_spec.lua b/spec/pragma/invalid_spec.lua new file mode 100644 index 000000000..105770481 --- /dev/null +++ b/spec/pragma/invalid_spec.lua @@ -0,0 +1,29 @@ +local util = require("spec.util") + +describe("invalid pragma", function() + it("rejects invalid pragma", util.check_syntax_error([[ + --#invalid_pragma on + ]], { + { y = 1, msg = "invalid token '--#invalid_pragma'" } + })) + + it("pragmas currently do not accept punctuation", util.check_syntax_error([[ + --#pragma something(other) + ]], { + { y = 1, msg = "invalid token '('" }, + { y = 1, msg = "invalid token ')'" }, + })) + + it("pragma arguments need to be in a single line", util.check_syntax_error([[ + --#pragma arity + on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { msg = "expected pragma value" } + })) +end) diff --git a/tl.lua b/tl.lua index d8e10b3e5..fafb1a100 100644 --- a/tl.lua +++ b/tl.lua @@ -804,6 +804,8 @@ end + + @@ -838,6 +840,9 @@ do + + + @@ -874,6 +879,9 @@ do ["number hexfloat"] = "number", ["number power"] = "number", ["number powersign"] = "$ERR invalid_number$", + ["pragma"] = "pragma", + ["pragma any"] = nil, + ["pragma word"] = "pragma_identifier", } local keywords = { @@ -1267,11 +1275,39 @@ do elseif state == "got --" then if c == "[" then state = "got --[" + elseif c == "#" then + state = "pragma" else fwd = false state = "comment short" drop_token() end + elseif state == "pragma" then + if not lex_word[c] then + end_token_prev("pragma") + if tokens[nt].tk ~= "--#pragma" then + add_syntax_error() + end + fwd = false + state = "pragma any" + end + elseif state == "pragma any" then + if c == "\n" then + state = "any" + elseif lex_word[c] then + state = "pragma word" + begin_token() + elseif not lex_space[c] then + begin_token() + end_token_here("$ERR invalid$") + add_syntax_error() + end + elseif state == "pragma word" then + if not lex_word[c] then + end_token_prev("pragma_identifier") + fwd = false + state = (c == "\n") and "any" or "pragma any" + end elseif state == "got 0" then if c == "x" or c == "X" then state = "number hex" @@ -1920,6 +1956,7 @@ end + local TruthyFact = {} @@ -2100,6 +2137,10 @@ local Node = {ExpectedContext = {}, } + + + + @@ -4220,7 +4261,27 @@ do return parse_function(ps, i, "record") end + local function parse_pragma(ps, i) + i = i + 1 + local pragma = new_node(ps, i, "pragma") + + if ps.tokens[i].kind ~= "pragma_identifier" then + return fail(ps, i, "expected pragma name") + end + pragma.pkey = ps.tokens[i].tk + i = i + 1 + + if ps.tokens[i].kind ~= "pragma_identifier" then + return fail(ps, i, "expected pragma value") + end + pragma.pvalue = ps.tokens[i].tk + i = i + 1 + + return i, pragma + end + local parse_statement_fns = { + ["--#pragma"] = parse_pragma, ["::"] = parse_label, ["do"] = parse_do, ["if"] = parse_if, @@ -4589,6 +4650,7 @@ local no_recurse_node = { ["break"] = true, ["label"] = true, ["number"] = true, + ["pragma"] = true, ["string"] = true, ["boolean"] = true, ["integer"] = true, @@ -5547,6 +5609,8 @@ function tl.pretty_print_ast(ast, gen_target, mode) return out end, }, + ["pragma"] = {}, + ["variable"] = emit_exactly_visitor_cbs, ["identifier"] = emit_exactly_visitor_cbs, @@ -6771,21 +6835,33 @@ function tl.search_module(module_name, search_dtl) return nil, nil, tried end -local function require_module(w, module_name, feat_lax, env) +local function require_module(w, module_name, opts, env) local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (feat_lax or found:match("tl$")) then + if found and (opts.feat_lax == "on" or found:match("tl$")) then env.module_filenames[module_name] = found env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) + local save_defaults = env.defaults + local defaults = { + feat_lax = opts.feat_lax or save_defaults.feat_lax, + feat_arity = opts.feat_arity or save_defaults.feat_arity, + gen_compat = opts.gen_compat or save_defaults.gen_compat, + gen_target = opts.gen_target or save_defaults.gen_target, + run_internal_compiler_checks = opts.run_internal_compiler_checks or save_defaults.run_internal_compiler_checks, + } + env.defaults = defaults + local found_result, err = tl.process(found, env, fd) assert(found_result, err) + env.defaults = save_defaults + env.modules[module_name] = found_result.type return found_result.type, found @@ -6991,7 +7067,11 @@ tl.new_env = function(opts) if opts.predefined_modules then for _, name in ipairs(opts.predefined_modules) do - local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) + local tc_opts = { + feat_lax = env.defaults.feat_lax, + feat_arity = env.defaults.feat_arity, + } + local module_type = require_module(w, name, tc_opts, env) if module_type.typename == "invalid" then return nil, string.format("Error: could not predefine module '%s'", name) @@ -7264,9 +7344,15 @@ do local function show_arity(f) local nfargs = #f.args.tuple - return f.min_arity < nfargs and - "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) or - tostring(nfargs or 0) + if f.min_arity < nfargs then + if f.min_arity > 0 then + return "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + else + return (f.args.is_va and "any number" or "at most " .. nfargs) + end + else + return tostring(nfargs or 0) + end end local function drop_constant_value(t) @@ -8918,7 +9004,11 @@ a.types[i], b.types[i]), } if self.feat_lax and is_unknown(func) then local unk = func - func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) + func = a_function(func, { + min_arity = 0, + args = a_vararg(func, { unk }), + rets = a_vararg(func, { unk }), + }) end func = self:to_structural(func) @@ -9561,9 +9651,9 @@ a.types[i], b.types[i]), } end end - function TypeChecker:add_function_definition_for_recursion(node, fnargs) + function TypeChecker:add_function_definition_for_recursion(node, fnargs, feat_arity) self:add_var(nil, node.name.tk, a_function(node, { - min_arity = node.min_arity, + min_arity = feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), @@ -10281,7 +10371,7 @@ a.types[i], b.types[i]), } local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) local msgh_type = a_function(arg2, { - min_arity = 1, + min_arity = self.feat_arity and 1 or 0, args = a_type(arg2, "tuple", { tuple = { a_type(arg2, "any", {}) } }), rets = a_type(arg2, "tuple", { tuple = {} }), }) @@ -10369,7 +10459,11 @@ a.types[i], b.types[i]), } end local module_name = assert(node.e2[1].conststr) - local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) + local tc_opts = { + feat_lax = self.feat_lax and "on" or "off", + feat_arity = self.feat_arity and "on" or "off", + } + local t, module_filename = require_module(node, module_name, tc_opts, self.env) if t.typename == "invalid" then if not module_filename then @@ -11528,7 +11622,7 @@ self:expand_type(node, values, elements) }) assert(args.typename == "tuple") self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self, node, children) local args = children[2] @@ -11539,7 +11633,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11568,7 +11662,7 @@ self:expand_type(node, values, elements) }) self:check_macroexp_arg_use(node.macrodef) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.macrodef.min_arity, + min_arity = self.feat_arity and node.macrodef.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11601,7 +11695,7 @@ self:expand_type(node, values, elements) }) assert(args.typename == "tuple") self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self, node, children) local args = children[2] @@ -11615,7 +11709,7 @@ self:expand_type(node, values, elements) }) end self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11686,7 +11780,7 @@ self:expand_type(node, values, elements) }) end local fn_type = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -11760,7 +11854,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11786,7 +11880,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = rets, @@ -12274,6 +12368,22 @@ self:expand_type(node, values, elements) }) return node.newtype end, }, + ["pragma"] = { + after = function(self, node, _children) + if node.pkey == "arity" then + if node.pvalue == "on" then + self.feat_arity = true + elseif node.pvalue == "off" then + self.feat_arity = false + else + return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue) + end + else + return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey) + end + return NONE + end, + }, ["error_node"] = { after = function(_self, node, _children) return a_type(node, "invalid", {}) @@ -12461,6 +12571,15 @@ self:expand_type(node, values, elements) }) local visit_type visit_type = { cbs = { + ["function"] = { + before = visit_type_with_typeargs.before, + after = function(self, typ, children) + if self.feat_arity == false then + typ.min_arity = 0 + end + return visit_type_with_typeargs.after(self, typ, children) + end, + }, ["record"] = { before = function(self, typ) self:begin_scope() @@ -12629,7 +12748,6 @@ self:expand_type(node, values, elements) }) visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["function"] = visit_type_with_typeargs visit_type.cbs["typedecl"] = visit_type_with_typeargs visit_type.cbs["typealias"] = visit_type_with_typeargs diff --git a/tl.tl b/tl.tl index cec12f333..5f150c924 100644 --- a/tl.tl +++ b/tl.tl @@ -793,6 +793,8 @@ local enum TokenKind "identifier" "number" "integer" + "pragma" + "pragma_identifier" "$ERR unfinished_comment$" "$ERR invalid_string$" "$ERR invalid_number$" @@ -840,6 +842,9 @@ do "number hexfloat" "number power" "number powersign" + "pragma" + "pragma word" + "pragma any" end local last_token_kind : {LexState:TokenKind} = { @@ -874,6 +879,9 @@ do ["number hexfloat"] = "number", ["number power"] = "number", ["number powersign"] = "$ERR invalid_number$", + ["pragma"] = "pragma", + ["pragma any"] = nil, -- never in a token + ["pragma word"] = "pragma_identifier", -- never in a token } local keywords: {string:boolean} = { @@ -1267,11 +1275,39 @@ do elseif state == "got --" then if c == "[" then state = "got --[" + elseif c == "#" then + state = "pragma" else fwd = false state = "comment short" drop_token() end + elseif state == "pragma" then + if not lex_word[c] then + end_token_prev("pragma") + if tokens[nt].tk ~= "--#pragma" then + add_syntax_error() + end + fwd = false + state = "pragma any" + end + elseif state == "pragma any" then + if c == "\n" then + state = "any" + elseif lex_word[c] then + state = "pragma word" + begin_token() + elseif not lex_space[c] then + begin_token() + end_token_here("$ERR invalid$") + add_syntax_error() + end + elseif state == "pragma word" then + if not lex_word[c] then + end_token_prev("pragma_identifier") + fwd = false + state = (c == "\n") and "any" or "pragma any" + end elseif state == "got 0" then if c == "x" or c == "X" then state = "number hex" @@ -1902,6 +1938,7 @@ local enum NodeKind "macroexp" "local_macroexp" "interface" + "pragma" "error_node" end @@ -2100,6 +2137,10 @@ local record Node itemtype: Type decltuple: TupleType + -- pragma + pkey: string + pvalue: string + opt: boolean debug_type: Type @@ -4220,7 +4261,27 @@ local function parse_record_function(ps: ParseState, i: integer): integer, Node return parse_function(ps, i, "record") end +local function parse_pragma(ps: ParseState, i: integer): integer, Node + i = i + 1 -- skip "--#pragma" + local pragma = new_node(ps, i, "pragma") + + if ps.tokens[i].kind ~= "pragma_identifier" then + return fail(ps, i, "expected pragma name") + end + pragma.pkey = ps.tokens[i].tk + i = i + 1 + + if ps.tokens[i].kind ~= "pragma_identifier" then + return fail(ps, i, "expected pragma value") + end + pragma.pvalue = ps.tokens[i].tk + i = i + 1 + + return i, pragma +end + local parse_statement_fns: {string : function(ParseState, integer):(integer, Node)} = { + ["--#pragma"] = parse_pragma, ["::"] = parse_label, ["do"] = parse_do, ["if"] = parse_if, @@ -4589,6 +4650,7 @@ local no_recurse_node: {NodeKind : boolean} = { ["break"] = true, ["label"] = true, ["number"] = true, + ["pragma"] = true, ["string"] = true, ["boolean"] = true, ["integer"] = true, @@ -5547,6 +5609,8 @@ function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode?: boolean | return out end, }, + ["pragma"] = { + }, ["variable"] = emit_exactly_visitor_cbs, ["identifier"] = emit_exactly_visitor_cbs, @@ -6771,21 +6835,33 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local function require_module(w: Where, module_name: string, feat_lax: boolean, env: Env): Type, string +local function require_module(w: Where, module_name: string, opts: TypeCheckOptions, env: Env): Type, string local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (feat_lax or found:match("tl$") as boolean) then + if found and (opts.feat_lax == "on" or found:match("tl$") as boolean) then env.module_filenames[module_name] = found env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) + local save_defaults = env.defaults + local defaults : TypeCheckOptions = { + feat_lax = opts.feat_lax or save_defaults.feat_lax, + feat_arity = opts.feat_arity or save_defaults.feat_arity, + gen_compat = opts.gen_compat or save_defaults.gen_compat, + gen_target = opts.gen_target or save_defaults.gen_target, + run_internal_compiler_checks = opts.run_internal_compiler_checks or save_defaults.run_internal_compiler_checks, + } + env.defaults = defaults + local found_result, err: Result, string = tl.process(found, env, fd) assert(found_result, err) + env.defaults = save_defaults + env.modules[module_name] = found_result.type return found_result.type, found @@ -6991,7 +7067,11 @@ tl.new_env = function(opts?: EnvOptions): Env, string if opts.predefined_modules then for _, name in ipairs(opts.predefined_modules) do - local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) + local tc_opts = { + feat_lax = env.defaults.feat_lax, + feat_arity = env.defaults.feat_arity, + } + local module_type = require_module(w, name, tc_opts, env) if module_type is InvalidType then return nil, string.format("Error: could not predefine module '%s'", name) @@ -7264,9 +7344,15 @@ do local function show_arity(f: FunctionType): string local nfargs = #f.args.tuple - return f.min_arity < nfargs - and "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) - or tostring(nfargs or 0) + if f.min_arity < nfargs then + if f.min_arity > 0 then + return "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + else + return (f.args.is_va and "any number" or "at most " .. nfargs) + end + else + return tostring(nfargs or 0) + end end local function drop_constant_value(t: Type): Type @@ -8918,7 +9004,11 @@ do -- resolve unknown in lax mode, produce a general unknown function if self.feat_lax and is_unknown(func) then local unk = func - func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) + func = a_function(func, { + min_arity = 0, + args = a_vararg(func, { unk }), + rets = a_vararg(func, { unk }) + }) end -- unwrap if tuple, resolve if nominal func = self:to_structural(func) @@ -9561,9 +9651,9 @@ do end end - function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType) + function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType, feat_arity: boolean) self:add_var(nil, node.name.tk, a_function(node, { - min_arity = node.min_arity, + min_arity = feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), @@ -10281,7 +10371,7 @@ do local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) local msgh_type = a_function(arg2, { - min_arity = 1, + min_arity = self.feat_arity and 1 or 0, args = a_tuple(arg2, { a_type(arg2, "any", {}) }), rets = a_tuple(arg2, {}) }) @@ -10369,7 +10459,11 @@ do end local module_name = assert(node.e2[1].conststr) - local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) + local tc_opts: TypeCheckOptions = { + feat_lax = self.feat_lax and "on" or "off", + feat_arity = self.feat_arity and "on" or "off", + } + local t, module_filename = require_module(node, module_name, tc_opts, self.env) if t.typename == "invalid" then if not module_filename then @@ -11528,7 +11622,7 @@ do assert(args is TupleType) self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] @@ -11539,7 +11633,7 @@ do self:end_function_scope(node) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11568,7 +11662,7 @@ do self:check_macroexp_arg_use(node.macrodef) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.macrodef.min_arity, + min_arity = self.feat_arity and node.macrodef.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11601,7 +11695,7 @@ do assert(args is TupleType) self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] @@ -11615,7 +11709,7 @@ do end self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11686,7 +11780,7 @@ do end local fn_type = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -11760,7 +11854,7 @@ do self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11786,7 +11880,7 @@ do self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = rets, @@ -12274,6 +12368,22 @@ do return node.newtype end, }, + ["pragma"] = { + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + if node.pkey == "arity" then + if node.pvalue == "on" then + self.feat_arity = true + elseif node.pvalue == "off" then + self.feat_arity = false + else + return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue) + end + else + return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey) + end + return NONE + end, + }, ["error_node"] = { after = function(_self: TypeChecker, node: Node, _children: {Type}): Type return an_invalid(node) @@ -12461,6 +12571,15 @@ do local visit_type: Visitor visit_type = { cbs = { + ["function"] = { + before = visit_type_with_typeargs.before, + after = function(self: TypeChecker, typ: FunctionType, children: {Type}): Type + if self.feat_arity == false then + typ.min_arity = 0 + end + return visit_type_with_typeargs.after(self, typ, children) + end + }, ["record"] = { before = function(self: TypeChecker, typ: RecordType) self:begin_scope() @@ -12629,7 +12748,6 @@ do visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["function"] = visit_type_with_typeargs visit_type.cbs["typedecl"] = visit_type_with_typeargs visit_type.cbs["typealias"] = visit_type_with_typeargs