From 6ef73e0cdb3d108262f40cf5ec0e5b0ce1290405 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 10 Apr 2023 16:15:31 -0500 Subject: [PATCH] [TVMScript] Distinguish between void* and handle (#14488) Prior to this PR, the type annotations `PrimType(DataType::Handle())` and `PointerType(PrimType(DataType::Void()))` are both represented as `T.handle` in TVMScript, which can cause failures to round-trip between TIR and TVMScript. This PR keeps `PrimType(DataType::Handle())` as `T.handle`, but updates the representation of `PointerType(PrimType(DataType::Void()))` to `T.handle("void")` in order to distinguish between these two cases. --- include/tvm/script/ir_builder/tir/ir.h | 18 ++++++++++++++---- python/tvm/script/ir_builder/tir/ir.py | 14 ++++++++++++-- .../unittest/test_tvmscript_roundtrip.py | 19 +++++++++++++++++++ 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index a0343b03955b..9ce478da43cd 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -428,16 +428,26 @@ void Evaluate(PrimExpr value); /*! * \brief Create a TIR var that represents a pointer + * * \param dtype The data type of the pointer. + * * \param storage_scope The storage scope of the pointer. + * * \param is_size_var Whether the pointer is a size var. + * + * \param is_unknown_type Used to distinguish between + * `PrimType(DataType::Handle())` and + * `PointerType(PrimType(DataType::Void()))`. If true, resolve dtype + * of `Void()` as `PrimType`, and if false resolve dtype of `Void()` + * as a `PointerType`. + * * \return The pointer. */ -inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), // - String storage_scope = "global", // - bool is_size_var = false) { +inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), + String storage_scope = "global", bool is_size_var = false, + bool is_unknown_type = false) { Type type_annotation{nullptr}; - if (dtype.is_void() && storage_scope == "global") { + if (is_unknown_type && storage_scope == "global") { type_annotation = PrimType(runtime::DataType::Handle()); } else { type_annotation = PointerType(PrimType(dtype), storage_scope); diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c3ced1e0338b..c8285ccc52ce 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1441,7 +1441,9 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member -def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: bool = False) -> Var: +def handle( + dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False +) -> Var: """Create a TIR var that represents a pointer. Parameters @@ -1460,7 +1462,15 @@ def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: b res : PrimExpr The new tir.Var with type handle or casted expression with type handle. """ - return _ffi_api.Handle(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member + is_unknown_type = dtype is None + if dtype is None: + dtype = "void" + return _ffi_api.Handle( # type: ignore[attr-defined] # pylint: disable=no-member + dtype, + storage_scope, + is_size_var, + is_unknown_type, + ) def void(expr: Optional[PrimExpr] = None, *, is_size_var: bool = False) -> PrimExpr: diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 52d99550be92..50b0ecdc5c40 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3332,6 +3332,25 @@ def func(): return func +def test_void_ptr_vs_handle(): + """Distinguish between void* and handle + + In the future, perhaps these should be de-duplicated by forbidding + one of the two C++ representations. + """ + # Generates PointerType(PrimType(DataType::Void())) + @T.prim_func + def void_ptr(out_ret_value: T.handle("void")): + T.evaluate(out_ret_value) + + # Generates PrimType(DataType::Handle()) + @T.prim_func + def handle(out_ret_value: T.handle): + T.evaluate(out_ret_value) + + assert not tvm.ir.structural_equal(void_ptr, handle) + + def void_ptr(): @T.prim_func def func(out_ret_value: T.handle("void")):