From 552f06ed450d59816eb3a85f7e810d9726dcce26 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sat, 30 Apr 2022 08:15:26 +0800 Subject: [PATCH] support round-trip for T.Ptr in tvmscript (#11179) --- python/tvm/script/parser.py | 10 +++++--- python/tvm/script/tir/ty.py | 24 ++++++++++++++++--- src/printer/tvmscript_printer.cc | 5 ++-- .../unittest/test_tvmscript_roundtrip.py | 13 ++++++++++ 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 13b283bc0c40..c26812db4062 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -1190,10 +1190,14 @@ def transform_TypeApply(self, node): ) param_types = [] - for param in node.params: + for idx, param in enumerate(node.params): param_type = self.transform(param) - if not isinstance(param_type, ty.TypeGeneric): - self.report_error(f"Expected a type but found {type(param).__name__}", param.span) + if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx): + self.report_error( + f"Expected a type but found {type(param).__name__} " + f"at {idx}th type argument", + param.span, + ) param_types.append(param_type) diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index a34169673ed3..dfe2fbbe42e9 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -31,6 +31,10 @@ def evaluate(self): """Return an actual ir.Type Object that this Generic class wraps""" raise TypeError("Cannot get tvm.Type from a generic type") + def require_type_generic_at(self, idx): # pylint: disable=unused-argument + """If True, the `idx`th type argument must be TypeGeneric""" + return True + # This function is added here to avoid a pylint error # for T.int/float below not being callable def __call__(self): @@ -66,13 +70,27 @@ def evaluate(self): class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method """TVM script typing class generator for PtrType - [] operator is overloaded, accepts a ConcreteType and returns a ConcreteType wrapping PtrType + [] operator is overloaded, accepts a ConcreteType and an optional storage scope string, + returns a ConcreteType wrapping PtrType """ - def __getitem__(self, vtype): + def __getitem__(self, args): + if isinstance(args, TypeGeneric): + args = [args] + if len(args) == 1: + vtype, scope = args[0], "global" + elif len(args) == 2: + vtype, scope = args[0], args[1] + else: + raise TypeError(f"Illegal type argument num for Ptr") if not isinstance(vtype, TypeGeneric): raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}") - return ConcreteType(tvm.ir.PointerType(vtype.evaluate())) + if not isinstance(scope, str): + raise TypeError(f"Ptr expects storage scope argument be a string") + return ConcreteType(tvm.ir.PointerType(vtype.evaluate(), scope)) + + def require_type_generic_at(self, idx): + return idx != 1 # the second argument is storage scope for Ptr class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index da5975cd5e28..aeb118a49c0e 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1229,10 +1229,11 @@ Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) { Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) { Doc doc; doc << tir_prefix_ << ".Ptr["; + doc << Print(node->element_type); if (!node->storage_scope.empty()) { - doc << node->storage_scope << " "; + doc << ", " << Doc::StrLiteral(node->storage_scope); } - doc << Print(node->element_type) << "]"; + doc << "]"; return doc; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 8f83df9c71f3..0437576462c4 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3241,6 +3241,18 @@ def string_annotation_of_special_chars(): return string_annotation_of_special_chars +def pointer_type(): + @T.prim_func + def func_with_ptr_type_annotations(x: T.Ptr[T.int32], y: T.Ptr[T.int32, "shared"]): + xx = T.allocate([16], "int32", "global") + yy = T.allocate([16], "int32", "shared") + a: T.Ptr[T.int32] = T.address_of(xx[0], dtype="handle") + b: T.Ptr[T.int32, "shared"] = T.address_of(yy[0], dtype="handle") + T.evaluate(T.call_extern("copy", a, b, dtype="")) + + return func_with_ptr_type_annotations + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3275,6 +3287,7 @@ def string_annotation_of_special_chars(): parse_bufferslice_as_range_bound, int64_support, string_annotation_escaping, + pointer_type, )