From 1b12d763109eeb074c568618cb78f8fd0d81819f Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Sun, 16 Jun 2024 13:37:35 -0300 Subject: [PATCH] pragma: arity on/off --- spec/cli/feat_spec.lua | 4 +- spec/pragma/arity_spec.lua | 237 +++++++++++++++++++++++++++++++++++++ tl.lua | 96 +++++++++++---- tl.tl | 91 ++++++++++---- 4 files changed, 384 insertions(+), 44 deletions(-) create mode 100644 spec/pragma/arity_spec.lua diff --git a/spec/cli/feat_spec.lua b/spec/cli/feat_spec.lua index 286033387..7741cc8f3 100644 --- a/spec/cli/feat_spec.lua +++ b/spec/cli/feat_spec.lua @@ -43,8 +43,8 @@ local test_cases = { status = 1, match = { "2 errors:", - ":9:22: wrong number of arguments (given 3, expects 2)", - ":19:22: wrong number of arguments (given 3, expects at least 1 and at most 2)", + ":9:22: wrong number of arguments (given 3, expects at most 2)", + ":19:22: wrong number of arguments (given 3, expects at most 2)", } } } diff --git a/spec/pragma/arity_spec.lua b/spec/pragma/arity_spec.lua new file mode 100644 index 000000000..02cf8ccc7 --- /dev/null +++ b/spec/pragma/arity_spec.lua @@ -0,0 +1,237 @@ +local util = require("spec.util") + +describe("pragma arity", function() + describe("on", function() + it("rejects function calls with missing arguments", util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { msg = "wrong number of arguments (given 1, expects 2)" } + })) + + it("accepts optional arguments", util.check([[ + --#pragma arity on + + local function f(x: integer, y?: integer) + print(x + (y or 20)) + end + + print(f(10)) + ]])) + end) + + describe("off", function() + it("accepts function calls with missing arguments", util.check([[ + --#pragma arity off + + local function f(x: integer, y: integer) + print(x + (y or 20)) + end + + print(f(10)) + ]])) + + it("ignores optional argument annotations", util.check([[ + --#pragma arity off + + local function f(x: integer, y?: integer) + print(x + y) + end + + print(f(10)) + ]])) + end) + + describe("no propagation from required module upwards:", function() + it("on then off, with error in 'on'", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity off + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + ]], { + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + + it("on then on, with errors in both", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity on + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + ]], { + { filename = "r.tl", y = 5, msg = "wrong number of arguments (given 1, expects 3)" }, + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + + it("off then on, with error in 'on'", function() + util.mock_io(finally, { + ["r.tl"] = [[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]] + }) + util.check_type_error([[ + --#pragma arity off + + local r = require("r") + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + ]], { + { y = 7, filename = "r.tl", msg = "wrong number of arguments (given 1, expects 2)" } + })() + end) + end) + + describe("does propagate downwards into required module:", function() + it("can trigger errors in required modules", function() + util.mock_io(finally, { + ["r.tl"] = [[ + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + + return { + f = f + } + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + local r = require("r") + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + + r.f(10) + ]], { + { filename = "r.tl", y = 4, msg = "wrong number of arguments (given 1, expects 3)" }, + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 15, msg = "wrong number of arguments (given 2, expects 4)" }, + { filename = "foo.tl", y = 17, msg = "wrong number of arguments (given 1, expects 3)" }, + })() + end) + + it("can be used to load modules with different settings", function() + util.mock_io(finally, { + ["r.tl"] = [[ + local function f(x: integer, y: integer, z: integer) + print(x + (y or 20)) + end + print(f(10)) + + return { + f = f + } + ]] + }) + util.check_type_error([[ + --#pragma arity on + + local function f(x: integer, y: integer) + print(x + y) + end + + print(f(10)) + + --#pragma arity off + local r = require("r") + --#pragma arity on + + local function g(x: integer, y: integer, z: integer, w: integer) + print(x + y) + end + + print(g(10, 20)) + + r.f(10) -- no error here! + ]], { + { filename = "foo.tl", y = 7, msg = "wrong number of arguments (given 1, expects 2)" }, + { filename = "foo.tl", y = 17, msg = "wrong number of arguments (given 2, expects 4)" }, + })() + end) + end) + + describe("invalid", function() + it("rejects invalid value", util.check_type_error([[ + --#pragma arity invalid_value + + local function f(x: integer, y?: integer) + print(x + y) + end + + print(f(10)) + ]], { + { y = 1, msg = "invalid value for pragma 'arity': invalid_value" } + })) + end) +end) diff --git a/tl.lua b/tl.lua index 8eaa70524..fafb1a100 100644 --- a/tl.lua +++ b/tl.lua @@ -1956,6 +1956,7 @@ end + local TruthyFact = {} @@ -2136,6 +2137,10 @@ local Node = {ExpectedContext = {}, } + + + + @@ -6830,21 +6835,33 @@ function tl.search_module(module_name, search_dtl) return nil, nil, tried end -local function require_module(w, module_name, feat_lax, env) +local function require_module(w, module_name, opts, env) local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (feat_lax or found:match("tl$")) then + if found and (opts.feat_lax == "on" or found:match("tl$")) then env.module_filenames[module_name] = found env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) + local save_defaults = env.defaults + local defaults = { + feat_lax = opts.feat_lax or save_defaults.feat_lax, + feat_arity = opts.feat_arity or save_defaults.feat_arity, + gen_compat = opts.gen_compat or save_defaults.gen_compat, + gen_target = opts.gen_target or save_defaults.gen_target, + run_internal_compiler_checks = opts.run_internal_compiler_checks or save_defaults.run_internal_compiler_checks, + } + env.defaults = defaults + local found_result, err = tl.process(found, env, fd) assert(found_result, err) + env.defaults = save_defaults + env.modules[module_name] = found_result.type return found_result.type, found @@ -7050,7 +7067,11 @@ tl.new_env = function(opts) if opts.predefined_modules then for _, name in ipairs(opts.predefined_modules) do - local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) + local tc_opts = { + feat_lax = env.defaults.feat_lax, + feat_arity = env.defaults.feat_arity, + } + local module_type = require_module(w, name, tc_opts, env) if module_type.typename == "invalid" then return nil, string.format("Error: could not predefine module '%s'", name) @@ -7323,9 +7344,15 @@ do local function show_arity(f) local nfargs = #f.args.tuple - return f.min_arity < nfargs and - "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) or - tostring(nfargs or 0) + if f.min_arity < nfargs then + if f.min_arity > 0 then + return "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + else + return (f.args.is_va and "any number" or "at most " .. nfargs) + end + else + return tostring(nfargs or 0) + end end local function drop_constant_value(t) @@ -8977,7 +9004,11 @@ a.types[i], b.types[i]), } if self.feat_lax and is_unknown(func) then local unk = func - func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) + func = a_function(func, { + min_arity = 0, + args = a_vararg(func, { unk }), + rets = a_vararg(func, { unk }), + }) end func = self:to_structural(func) @@ -9620,9 +9651,9 @@ a.types[i], b.types[i]), } end end - function TypeChecker:add_function_definition_for_recursion(node, fnargs) + function TypeChecker:add_function_definition_for_recursion(node, fnargs, feat_arity) self:add_var(nil, node.name.tk, a_function(node, { - min_arity = node.min_arity, + min_arity = feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), @@ -10340,7 +10371,7 @@ a.types[i], b.types[i]), } local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) local msgh_type = a_function(arg2, { - min_arity = 1, + min_arity = self.feat_arity and 1 or 0, args = a_type(arg2, "tuple", { tuple = { a_type(arg2, "any", {}) } }), rets = a_type(arg2, "tuple", { tuple = {} }), }) @@ -10428,7 +10459,11 @@ a.types[i], b.types[i]), } end local module_name = assert(node.e2[1].conststr) - local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) + local tc_opts = { + feat_lax = self.feat_lax and "on" or "off", + feat_arity = self.feat_arity and "on" or "off", + } + local t, module_filename = require_module(node, module_name, tc_opts, self.env) if t.typename == "invalid" then if not module_filename then @@ -11587,7 +11622,7 @@ self:expand_type(node, values, elements) }) assert(args.typename == "tuple") self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self, node, children) local args = children[2] @@ -11598,7 +11633,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11627,7 +11662,7 @@ self:expand_type(node, values, elements) }) self:check_macroexp_arg_use(node.macrodef) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.macrodef.min_arity, + min_arity = self.feat_arity and node.macrodef.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11660,7 +11695,7 @@ self:expand_type(node, values, elements) }) assert(args.typename == "tuple") self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self, node, children) local args = children[2] @@ -11674,7 +11709,7 @@ self:expand_type(node, values, elements) }) end self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11745,7 +11780,7 @@ self:expand_type(node, values, elements) }) end local fn_type = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -11819,7 +11854,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11845,7 +11880,7 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = rets, @@ -12334,7 +12369,18 @@ self:expand_type(node, values, elements) }) end, }, ["pragma"] = { - after = function(_self, _node, _children) + after = function(self, node, _children) + if node.pkey == "arity" then + if node.pvalue == "on" then + self.feat_arity = true + elseif node.pvalue == "off" then + self.feat_arity = false + else + return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue) + end + else + return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey) + end return NONE end, }, @@ -12525,6 +12571,15 @@ self:expand_type(node, values, elements) }) local visit_type visit_type = { cbs = { + ["function"] = { + before = visit_type_with_typeargs.before, + after = function(self, typ, children) + if self.feat_arity == false then + typ.min_arity = 0 + end + return visit_type_with_typeargs.after(self, typ, children) + end, + }, ["record"] = { before = function(self, typ) self:begin_scope() @@ -12693,7 +12748,6 @@ self:expand_type(node, values, elements) }) visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["function"] = visit_type_with_typeargs visit_type.cbs["typedecl"] = visit_type_with_typeargs visit_type.cbs["typealias"] = visit_type_with_typeargs diff --git a/tl.tl b/tl.tl index 08baa5db3..5f150c924 100644 --- a/tl.tl +++ b/tl.tl @@ -6835,21 +6835,33 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local function require_module(w: Where, module_name: string, feat_lax: boolean, env: Env): Type, string +local function require_module(w: Where, module_name: string, opts: TypeCheckOptions, env: Env): Type, string local mod = env.modules[module_name] if mod then return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (feat_lax or found:match("tl$") as boolean) then + if found and (opts.feat_lax == "on" or found:match("tl$") as boolean) then env.module_filenames[module_name] = found env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) + local save_defaults = env.defaults + local defaults : TypeCheckOptions = { + feat_lax = opts.feat_lax or save_defaults.feat_lax, + feat_arity = opts.feat_arity or save_defaults.feat_arity, + gen_compat = opts.gen_compat or save_defaults.gen_compat, + gen_target = opts.gen_target or save_defaults.gen_target, + run_internal_compiler_checks = opts.run_internal_compiler_checks or save_defaults.run_internal_compiler_checks, + } + env.defaults = defaults + local found_result, err: Result, string = tl.process(found, env, fd) assert(found_result, err) + env.defaults = save_defaults + env.modules[module_name] = found_result.type return found_result.type, found @@ -7055,7 +7067,11 @@ tl.new_env = function(opts?: EnvOptions): Env, string if opts.predefined_modules then for _, name in ipairs(opts.predefined_modules) do - local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) + local tc_opts = { + feat_lax = env.defaults.feat_lax, + feat_arity = env.defaults.feat_arity, + } + local module_type = require_module(w, name, tc_opts, env) if module_type is InvalidType then return nil, string.format("Error: could not predefine module '%s'", name) @@ -7328,9 +7344,15 @@ do local function show_arity(f: FunctionType): string local nfargs = #f.args.tuple - return f.min_arity < nfargs - and "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) - or tostring(nfargs or 0) + if f.min_arity < nfargs then + if f.min_arity > 0 then + return "at least " .. f.min_arity .. (f.args.is_va and "" or " and at most " .. nfargs) + else + return (f.args.is_va and "any number" or "at most " .. nfargs) + end + else + return tostring(nfargs or 0) + end end local function drop_constant_value(t: Type): Type @@ -8982,7 +9004,11 @@ do -- resolve unknown in lax mode, produce a general unknown function if self.feat_lax and is_unknown(func) then local unk = func - func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) + func = a_function(func, { + min_arity = 0, + args = a_vararg(func, { unk }), + rets = a_vararg(func, { unk }) + }) end -- unwrap if tuple, resolve if nominal func = self:to_structural(func) @@ -9625,9 +9651,9 @@ do end end - function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType) + function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType, feat_arity: boolean) self:add_var(nil, node.name.tk, a_function(node, { - min_arity = node.min_arity, + min_arity = feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), @@ -10345,7 +10371,7 @@ do local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) local msgh_type = a_function(arg2, { - min_arity = 1, + min_arity = self.feat_arity and 1 or 0, args = a_tuple(arg2, { a_type(arg2, "any", {}) }), rets = a_tuple(arg2, {}) }) @@ -10433,7 +10459,11 @@ do end local module_name = assert(node.e2[1].conststr) - local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) + local tc_opts: TypeCheckOptions = { + feat_lax = self.feat_lax and "on" or "off", + feat_arity = self.feat_arity and "on" or "off", + } + local t, module_filename = require_module(node, module_name, tc_opts, self.env) if t.typename == "invalid" then if not module_filename then @@ -11592,7 +11622,7 @@ do assert(args is TupleType) self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] @@ -11603,7 +11633,7 @@ do self:end_function_scope(node) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11632,7 +11662,7 @@ do self:check_macroexp_arg_use(node.macrodef) local t = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.macrodef.min_arity, + min_arity = self.feat_arity and node.macrodef.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11665,7 +11695,7 @@ do assert(args is TupleType) self:add_internal_function_variables(node, args) - self:add_function_definition_for_recursion(node, args) + self:add_function_definition_for_recursion(node, args, self.feat_arity) end, after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] @@ -11679,7 +11709,7 @@ do end self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11750,7 +11780,7 @@ do end local fn_type = self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, typeargs = node.typeargs, args = args, @@ -11824,7 +11854,7 @@ do self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = self.get_rets(rets), @@ -11850,7 +11880,7 @@ do self:end_function_scope(node) return self:ensure_fresh_typeargs(a_function(node, { - min_arity = node.min_arity, + min_arity = self.feat_arity and node.min_arity or 0, typeargs = node.typeargs, args = args, rets = rets, @@ -12339,7 +12369,18 @@ do end, }, ["pragma"] = { - after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + if node.pkey == "arity" then + if node.pvalue == "on" then + self.feat_arity = true + elseif node.pvalue == "off" then + self.feat_arity = false + else + return self.errs:invalid_at(node, "invalid value for pragma 'arity': " .. node.pvalue) + end + else + return self.errs:invalid_at(node, "invalid pragma: " .. node.pkey) + end return NONE end, }, @@ -12530,6 +12571,15 @@ do local visit_type: Visitor visit_type = { cbs = { + ["function"] = { + before = visit_type_with_typeargs.before, + after = function(self: TypeChecker, typ: FunctionType, children: {Type}): Type + if self.feat_arity == false then + typ.min_arity = 0 + end + return visit_type_with_typeargs.after(self, typ, children) + end + }, ["record"] = { before = function(self: TypeChecker, typ: RecordType) self:begin_scope() @@ -12698,7 +12748,6 @@ do visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["function"] = visit_type_with_typeargs visit_type.cbs["typedecl"] = visit_type_with_typeargs visit_type.cbs["typealias"] = visit_type_with_typeargs