diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 02582e29e323..6cb22aeb5f47 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -566,7 +566,18 @@ def transform_Assign(self, node): self.context.remove_symbol(var.name) return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) - self.report_error("Unsupported Assign stmt", node.span) + self.report_error( + """Assignments should be either + 1. A "special statement" with return value + 1.1 Buffer = T.match_buffer()/T.buffer_decl() + 1.2 Var = T.var() + 1.3 Var = T.env_thread() + 2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr + 3. A store into a variable: Var[PrimExpr] = PrimExpr + 4. A with scope handler with concise scoping and var def + 4.1 var = T.allocate()""", + node.span, + ) def transform_SubscriptAssign(self, node): """Visitor for statements of the form :code:`x[1] = 2`.""" @@ -583,6 +594,12 @@ def transform_SubscriptAssign(self, node): span=tvm_span_from_synr(node.span), ) else: + if symbol.dtype == "handle" and len(indexes) != 1: + self.report_error( + "Handles only support one-dimensional indexing. Use `T.match_buffer` to " + "construct a multidimensional buffer from a handle.", + node.params[0].span, + ) if len(indexes) != 1: self.report_error( f"Store is only allowed with one index, but {len(indexes)} were provided.", @@ -736,9 +753,35 @@ def transform_Call(self, node): return self.transform_Subscript(node) if node.func_name.name in self._binop_maker: lhs = self.transform(node.params[0]) + # There is no supertype for everything that can appear in + # an expression, so we manually add what we might get here. + if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)): + # We would really like to report a more specific + # error here, but this parser contains no distinction + # between parsing statements and parsing expressions. All + # rules just call `transform`. + self.report_error( + f"Left hand side of binary op must be a PrimExpr, " + "but it is a {type(lhs).__name__}", + node.params[0].span, + ) rhs = self.transform(node.params[1]) - return self._binop_maker[node.func_name.name]( - lhs, rhs, span=tvm_span_from_synr(node.span) + if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)): + self.report_error( + f"Right hand side of binary op must be a PrimExpr, " + "but it is a {type(rhs).__name__}", + node.params[1].span, + ) + return call_with_error_reporting( + self.report_error, + node.span, + lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name]( + lhs, rhs, span=span + ), + node, + lhs, + rhs, + tvm_span_from_synr(node.span), ) if node.func_name.name in self._unaryop_maker: rhs = self.transform(node.params[0]) @@ -764,6 +807,8 @@ def transform_Call(self, node): self.transform(k): self.transform(v) for k, v in node.keyword_params.items() } if isinstance(func, tvm.tir.op.Op): + if not "dtype" in kw_args.keys(): + self.report_error(f"{func} requires a dtype keyword argument.", node.span) # pattern 2 return tvm.tir.Call( kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span) @@ -862,15 +907,33 @@ def transform_Subscript(self, node): indexes = [self.transform(x) for x in node.params[1].values] if isinstance(symbol, tvm.tir.expr.Var): - for index in indexes: - if not isinstance(index, (tvm.tir.PrimExpr, int)): - self.report_error( - "Buffer load indexes should be int or PrimExpr, but they are " - + type(index), - node.span, - ) - return tvm.tir.Load( - "float32", symbol, indexes, True, span=tvm_span_from_synr(node.span) + if symbol.dtype == "handle": + self.report_error( + "Cannot read directly from a handle, use `T.match_buffer` " + "to create a buffer to read from.", + node.params[0].span, + ) + if len(indexes) > 1: + self.report_error( + "Only a single index can be provided when indexing into a `var`.", + node.params[1].span, + ) + index = indexes[0] + if not isinstance(index, (tvm.tir.PrimExpr, int)): + self.report_error( + "Var load index should be an int or PrimExpr, but it is a" + type(index), + node.span, + ) + + return call_with_error_reporting( + self.report_error, + node.span, + tvm.tir.Load, + "float32", + symbol, + index, + True, + span=tvm_span_from_synr(node.span), ) elif isinstance(symbol, tvm.tir.Buffer): return BufferSlice( diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 1d7c959d990d..fbbd4a9522eb 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -31,33 +31,34 @@ namespace tvm { namespace tir { -#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b, Span span) { \ - using T = Name::ContainerType; \ - ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ - ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ - ICHECK(a.dtype() == b.dtype()) \ - << "TypeError: mismatched types. " << a.dtype() << " vs. " << b.dtype() << "\n"; \ - ObjectPtr node = make_object(); \ - node->dtype = a.dtype(); \ - node->a = std::move(a); \ - node->b = std::move(b); \ - node->span = std::move(span); \ - data_ = std::move(node); \ +#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ + using T = Name::ContainerType; \ + ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ + ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ + << b.dtype() << "\n"; \ + ObjectPtr node = make_object(); \ + node->dtype = a.dtype(); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + node->span = std::move(span); \ + data_ = std::move(node); \ } -#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b, Span span) { \ - using T = Name::ContainerType; \ - ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ - ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ - ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \ - ObjectPtr node = make_object(); \ - node->dtype = DataType::Bool(a.dtype().lanes()); \ - node->a = std::move(a); \ - node->b = std::move(b); \ - node->span = std::move(span); \ - data_ = std::move(node); \ +#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ + using T = Name::ContainerType; \ + ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ + ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ + << b.dtype() << "\n"; \ + ObjectPtr node = make_object(); \ + node->dtype = DataType::Bool(a.dtype().lanes()); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + node->span = std::move(span); \ + data_ = std::move(node); \ } // Var diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 105b4a2d6a3f..882745704693 100644 --- a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -314,6 +314,12 @@ def test_complete_alloc_buffer(): tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func) +@T.prim_func +def load_var() -> None: + d = T.var("float32") + d[1] = d[1] + + if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 11b360287cb7..c5f8993ade58 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -614,5 +614,68 @@ def test_fuse_fail_nested_loop_outer(): assert expected_sub_error_message in str(execinfo.value) +def load_var_multiple() -> None: + d = T.var("float32") + d[2] = d[2, 1] # error cannot provide two indices to load + + +def test_load_var(): + check_error(load_var_multiple, 3) + + +def store_var_multiple() -> None: + d = T.var("float32") + d[2, 1] = d[1] # error cannot provide two indices to store + + +def test_store_var(): + check_error(store_var_multiple, 3) + + +def load_handle(h: T.handle) -> None: + h_ = T.match_buffer(h, [1]) + h_[0] = h[0] # error cannot load from handle + + +def test_load_handle(): + check_error(load_var_multiple, 3) + + +def store_handle(h: T.handle) -> None: + h_ = T.match_buffer(h, [1]) + h[0] = h_[0] # error cannot store to handle + + +def test_store_handle(): + check_error(store_var_multiple, 3) + + +def binop_bad_ast_type(h: T.handle): + h_ = T.match_buffer(h, [1]) + h_[0] = h + [2] # error rhs should be a primexpr + + +def test_binop_bad_ast_type(): + check_error(binop_bad_ast_type, 3) + + +def binop_bad_type(h: T.handle): + h_ = T.match_buffer(h, [1]) + h_[0] = h + 2 # error lhs and rhs should be the same type + + +def test_binop_bad_type(): + check_error(binop_bad_type, 3) + + +def floor_dtype(h: T.handle): + h_ = T.match_buffer(h, [1]) + h_[0] = T.floor(2) # error floor requires a dtype + + +def test_floor_dtype(): + check_error(floor_dtype, 3) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))