Skip to content

Commit

Permalink
[TVMSCRIPT] Misc error message improvements (#9543)
Browse files Browse the repository at this point in the history
* [TVMSCRIPT] Misc error message improvements

* only prevent indexing into handles with multiple indexes

* lint
  • Loading branch information
Tristan Konolige authored Dec 3, 2021
1 parent 1d40ffb commit 7f683da
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 37 deletions.
87 changes: 75 additions & 12 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,18 @@ def transform_Assign(self, node):
self.context.remove_symbol(var.name)
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))

self.report_error("Unsupported Assign stmt", node.span)
self.report_error(
"""Assignments should be either
1. A "special statement" with return value
1.1 Buffer = T.match_buffer()/T.buffer_decl()
1.2 Var = T.var()
1.3 Var = T.env_thread()
2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
3. A store into a variable: Var[PrimExpr] = PrimExpr
4. A with scope handler with concise scoping and var def
4.1 var = T.allocate()""",
node.span,
)

def transform_SubscriptAssign(self, node):
"""Visitor for statements of the form :code:`x[1] = 2`."""
Expand All @@ -583,6 +594,12 @@ def transform_SubscriptAssign(self, node):
span=tvm_span_from_synr(node.span),
)
else:
if symbol.dtype == "handle" and len(indexes) != 1:
self.report_error(
"Handles only support one-dimensional indexing. Use `T.match_buffer` to "
"construct a multidimensional buffer from a handle.",
node.params[0].span,
)
if len(indexes) != 1:
self.report_error(
f"Store is only allowed with one index, but {len(indexes)} were provided.",
Expand Down Expand Up @@ -736,9 +753,35 @@ def transform_Call(self, node):
return self.transform_Subscript(node)
if node.func_name.name in self._binop_maker:
lhs = self.transform(node.params[0])
# There is no supertype for everything that can appear in
# an expression, so we manually add what we might get here.
if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)):
# We would really like to report a more specific
# error here, but this parser contains no distinction
# between parsing statements and parsing expressions. All
# rules just call `transform`.
self.report_error(
f"Left hand side of binary op must be a PrimExpr, "
"but it is a {type(lhs).__name__}",
node.params[0].span,
)
rhs = self.transform(node.params[1])
return self._binop_maker[node.func_name.name](
lhs, rhs, span=tvm_span_from_synr(node.span)
if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)):
self.report_error(
f"Right hand side of binary op must be a PrimExpr, "
"but it is a {type(rhs).__name__}",
node.params[1].span,
)
return call_with_error_reporting(
self.report_error,
node.span,
lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name](
lhs, rhs, span=span
),
node,
lhs,
rhs,
tvm_span_from_synr(node.span),
)
if node.func_name.name in self._unaryop_maker:
rhs = self.transform(node.params[0])
Expand All @@ -764,6 +807,8 @@ def transform_Call(self, node):
self.transform(k): self.transform(v) for k, v in node.keyword_params.items()
}
if isinstance(func, tvm.tir.op.Op):
if not "dtype" in kw_args.keys():
self.report_error(f"{func} requires a dtype keyword argument.", node.span)
# pattern 2
return tvm.tir.Call(
kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span)
Expand Down Expand Up @@ -862,15 +907,33 @@ def transform_Subscript(self, node):

indexes = [self.transform(x) for x in node.params[1].values]
if isinstance(symbol, tvm.tir.expr.Var):
for index in indexes:
if not isinstance(index, (tvm.tir.PrimExpr, int)):
self.report_error(
"Buffer load indexes should be int or PrimExpr, but they are "
+ type(index),
node.span,
)
return tvm.tir.Load(
"float32", symbol, indexes, True, span=tvm_span_from_synr(node.span)
if symbol.dtype == "handle":
self.report_error(
"Cannot read directly from a handle, use `T.match_buffer` "
"to create a buffer to read from.",
node.params[0].span,
)
if len(indexes) > 1:
self.report_error(
"Only a single index can be provided when indexing into a `var`.",
node.params[1].span,
)
index = indexes[0]
if not isinstance(index, (tvm.tir.PrimExpr, int)):
self.report_error(
"Var load index should be an int or PrimExpr, but it is a" + type(index),
node.span,
)

return call_with_error_reporting(
self.report_error,
node.span,
tvm.tir.Load,
"float32",
symbol,
index,
True,
span=tvm_span_from_synr(node.span),
)
elif isinstance(symbol, tvm.tir.Buffer):
return BufferSlice(
Expand Down
51 changes: 26 additions & 25 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,34 @@
namespace tvm {
namespace tir {

#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
using T = Name::ContainerType; \
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
ICHECK(a.dtype() == b.dtype()) \
<< "TypeError: mismatched types. " << a.dtype() << " vs. " << b.dtype() << "\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = a.dtype(); \
node->a = std::move(a); \
node->b = std::move(b); \
node->span = std::move(span); \
data_ = std::move(node); \
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
using T = Name::ContainerType; \
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
<< b.dtype() << "\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = a.dtype(); \
node->a = std::move(a); \
node->b = std::move(b); \
node->span = std::move(span); \
data_ = std::move(node); \
}

#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
using T = Name::ContainerType; \
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = DataType::Bool(a.dtype().lanes()); \
node->a = std::move(a); \
node->b = std::move(b); \
node->span = std::move(span); \
data_ = std::move(node); \
#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
using T = Name::ContainerType; \
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
<< b.dtype() << "\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = DataType::Bool(a.dtype().lanes()); \
node->a = std::move(a); \
node->b = std::move(b); \
node->span = std::move(span); \
data_ = std::move(node); \
}

// Var
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_tvmscript_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,12 @@ def test_complete_alloc_buffer():
tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func)


@T.prim_func
def load_var() -> None:
d = T.var("float32")
d[1] = d[1]


if __name__ == "__main__":
test_complete_matmul()
test_complete_matmul_original()
Expand Down
63 changes: 63 additions & 0 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,5 +614,68 @@ def test_fuse_fail_nested_loop_outer():
assert expected_sub_error_message in str(execinfo.value)


def load_var_multiple() -> None:
d = T.var("float32")
d[2] = d[2, 1] # error cannot provide two indices to load


def test_load_var():
check_error(load_var_multiple, 3)


def store_var_multiple() -> None:
d = T.var("float32")
d[2, 1] = d[1] # error cannot provide two indices to store


def test_store_var():
check_error(store_var_multiple, 3)


def load_handle(h: T.handle) -> None:
h_ = T.match_buffer(h, [1])
h_[0] = h[0] # error cannot load from handle


def test_load_handle():
check_error(load_var_multiple, 3)


def store_handle(h: T.handle) -> None:
h_ = T.match_buffer(h, [1])
h[0] = h_[0] # error cannot store to handle


def test_store_handle():
check_error(store_var_multiple, 3)


def binop_bad_ast_type(h: T.handle):
h_ = T.match_buffer(h, [1])
h_[0] = h + [2] # error rhs should be a primexpr


def test_binop_bad_ast_type():
check_error(binop_bad_ast_type, 3)


def binop_bad_type(h: T.handle):
h_ = T.match_buffer(h, [1])
h_[0] = h + 2 # error lhs and rhs should be the same type


def test_binop_bad_type():
check_error(binop_bad_type, 3)


def floor_dtype(h: T.handle):
h_ = T.match_buffer(h, [1])
h_[0] = T.floor(2) # error floor requires a dtype


def test_floor_dtype():
check_error(floor_dtype, 3)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 7f683da

Please sign in to comment.