Skip to content

Commit

Permalink
support round-trip for T.Ptr in tvmscript (#11179)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored Apr 30, 2022
1 parent b772d27 commit 552f06e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
10 changes: 7 additions & 3 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,10 +1190,14 @@ def transform_TypeApply(self, node):
)

param_types = []
for param in node.params:
for idx, param in enumerate(node.params):
param_type = self.transform(param)
if not isinstance(param_type, ty.TypeGeneric):
self.report_error(f"Expected a type but found {type(param).__name__}", param.span)
if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx):
self.report_error(
f"Expected a type but found {type(param).__name__} "
f"at {idx}th type argument",
param.span,
)

param_types.append(param_type)

Expand Down
24 changes: 21 additions & 3 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def evaluate(self):
"""Return an actual ir.Type Object that this Generic class wraps"""
raise TypeError("Cannot get tvm.Type from a generic type")

def require_type_generic_at(self, idx): # pylint: disable=unused-argument
"""If True, the `idx`th type argument must be TypeGeneric"""
return True

# This function is added here to avoid a pylint error
# for T.int/float below not being callable
def __call__(self):
Expand Down Expand Up @@ -66,13 +70,27 @@ def evaluate(self):
class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
"""TVM script typing class generator for PtrType
[] operator is overloaded, accepts a ConcreteType and returns a ConcreteType wrapping PtrType
[] operator is overloaded, accepts a ConcreteType and an optional storage scope string,
returns a ConcreteType wrapping PtrType
"""

def __getitem__(self, vtype):
def __getitem__(self, args):
if isinstance(args, TypeGeneric):
args = [args]
if len(args) == 1:
vtype, scope = args[0], "global"
elif len(args) == 2:
vtype, scope = args[0], args[1]
else:
raise TypeError(f"Illegal type argument num for Ptr")
if not isinstance(vtype, TypeGeneric):
raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
return ConcreteType(tvm.ir.PointerType(vtype.evaluate()))
if not isinstance(scope, str):
raise TypeError(f"Ptr expects storage scope argument be a string")
return ConcreteType(tvm.ir.PointerType(vtype.evaluate(), scope))

def require_type_generic_at(self, idx):
return idx != 1 # the second argument is storage scope for Ptr


class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method
Expand Down
5 changes: 3 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1229,10 +1229,11 @@ Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) {
Doc doc;
doc << tir_prefix_ << ".Ptr[";
doc << Print(node->element_type);
if (!node->storage_scope.empty()) {
doc << node->storage_scope << " ";
doc << ", " << Doc::StrLiteral(node->storage_scope);
}
doc << Print(node->element_type) << "]";
doc << "]";
return doc;
}

Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3241,6 +3241,18 @@ def string_annotation_of_special_chars():
return string_annotation_of_special_chars


def pointer_type():
@T.prim_func
def func_with_ptr_type_annotations(x: T.Ptr[T.int32], y: T.Ptr[T.int32, "shared"]):
xx = T.allocate([16], "int32", "global")
yy = T.allocate([16], "int32", "shared")
a: T.Ptr[T.int32] = T.address_of(xx[0], dtype="handle")
b: T.Ptr[T.int32, "shared"] = T.address_of(yy[0], dtype="handle")
T.evaluate(T.call_extern("copy", a, b, dtype=""))

return func_with_ptr_type_annotations


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3275,6 +3287,7 @@ def string_annotation_of_special_chars():
parse_bufferslice_as_range_bound,
int64_support,
string_annotation_escaping,
pointer_type,
)


Expand Down

0 comments on commit 552f06e

Please sign in to comment.