diff --git a/spec/metamethods/add_spec.lua b/spec/metamethods/add_spec.lua index aa8466981..89555a9c6 100644 --- a/spec/metamethods/add_spec.lua +++ b/spec/metamethods/add_spec.lua @@ -69,4 +69,35 @@ describe("binary metamethod __add", function() print((10 + s).x) ]])) + + it("preserves nominal type checking when resolving metamethods for operators", util.check_type_error([[ + local type Temperature = record + n: number + metamethod __add: function(t1: Temperature, t2: Temperature): Temperature + end + + local type Date = record + n: number + metamethod __add: function(t1: Date, t2: Date): Date + end + + local temp2: Temperature = { n = 45 } + local birthday2 : Date = { n = 34 } + + setmetatable(temp2, { + __add = function(t1: Temperature, t2: Temperature): Temperature + return { n = t1.n + t2.n } + end, + }) + + setmetatable(birthday2, { + __add = function(t1: Date, t2: Date): Date + return { n = t1.n + t2.n } + end, + }) + + print((temp2 + birthday2).n) + ]], { + { y = 26, msg = "Date is not a Temperature" }, + })) end) diff --git a/tl.lua b/tl.lua index a17a96991..1288f5297 100644 --- a/tl.lua +++ b/tl.lua @@ -7621,7 +7621,7 @@ tl.type_check = function(ast, opts) end end - local function check_metamethod(node, op, a, b) + local function check_metamethod(node, op, a, b, orig_a, orig_b) local method_name local where_args local args @@ -7636,11 +7636,11 @@ tl.type_check = function(ast, opts) if a and b then method_name = binop_to_metamethod[op] where_args = { node.e1, node.e2 } - args = { typename = "tuple", a, b } + args = { typename = "tuple", orig_a, orig_b } else method_name = unop_to_metamethod[op] where_args = { node.e1 } - args = { typename = "tuple", a } + args = { typename = "tuple", orig_a } end local metamethod = a.meta_fields and a.meta_fields[method_name or ""] @@ -7679,7 +7679,7 @@ tl.type_check = function(ast, opts) return tbl.fields[key] end - local meta_t = check_metamethod(rec, "@index", tbl, STRING) + local meta_t = check_metamethod(rec, "@index", tbl, STRING, tbl, STRING) if meta_t then return meta_t end @@ -8050,7 +8050,7 @@ tl.type_check = function(ast, opts) errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b end - local meta_t = check_metamethod(anode, "@index", a, orig_b) + local meta_t = check_metamethod(anode, "@index", a, orig_b, orig_a, orig_b) if meta_t then return meta_t end @@ -10004,7 +10004,7 @@ tl.type_check = function(ast, opts) node.type = types_op[a.typename] local meta_on_operator if not node.type then - node.type, meta_on_operator = check_metamethod(node, node.op.op, a) + node.type, meta_on_operator = check_metamethod(node, node.op.op, a, nil, orig_a, nil) if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) end @@ -10050,7 +10050,7 @@ tl.type_check = function(ast, opts) node.type = types_op[a.typename] and types_op[a.typename][b.typename] local meta_on_operator if not node.type then - node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b) + node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b, orig_a, orig_b) if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) if node.op.op == "or" and is_valid_union(unite({ orig_a, orig_b })) then diff --git a/tl.tl b/tl.tl index 1c857442a..ecc7ab67e 100644 --- a/tl.tl +++ b/tl.tl @@ -7621,7 +7621,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function check_metamethod(node: Node, op: string, a: Type, b: Type): Type, integer + local function check_metamethod(node: Node, op: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer local method_name: string local where_args: {Node} local args: Type @@ -7636,11 +7636,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if a and b then method_name = binop_to_metamethod[op] where_args = { node.e1, node.e2 } - args = { typename = "tuple", a, b } + args = { typename = "tuple", orig_a, orig_b } else method_name = unop_to_metamethod[op] where_args = { node.e1 } - args = { typename = "tuple", a } + args = { typename = "tuple", orig_a } end local metamethod = a.meta_fields and a.meta_fields[method_name or ""] @@ -7679,7 +7679,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return tbl.fields[key] end - local meta_t = check_metamethod(rec, "@index", tbl, STRING) + local meta_t = check_metamethod(rec, "@index", tbl, STRING, tbl, STRING) if meta_t then return meta_t end @@ -8050,7 +8050,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b end - local meta_t = check_metamethod(anode, "@index", a, orig_b) + local meta_t = check_metamethod(anode, "@index", a, orig_b, orig_a, orig_b) if meta_t then return meta_t end @@ -10004,7 +10004,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.type = types_op[a.typename] local meta_on_operator: integer if not node.type then - node.type, meta_on_operator = check_metamethod(node, node.op.op, a) + node.type, meta_on_operator = check_metamethod(node, node.op.op, a, nil, orig_a, nil) if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a)) end @@ -10050,7 +10050,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.type = types_op[a.typename] and types_op[a.typename][b.typename] local meta_on_operator: integer if not node.type then - node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b) + node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b, orig_a, orig_b) if not node.type then node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b)) if node.op.op == "or" and is_valid_union(unite({orig_a, orig_b})) then