Skip to content

Commit

Permalink
[TVMScript] Distinguish between void* and handle
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Apr 4, 2023
1 parent b724c87 commit ad97955
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 6 deletions.
18 changes: 14 additions & 4 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,16 +428,26 @@ void Evaluate(PrimExpr value);

/*!
* \brief Create a TIR var that represents a pointer
*
* \param dtype The data type of the pointer.
*
* \param storage_scope The storage scope of the pointer.
*
* \param is_size_var Whether the pointer is a size var.
*
* \param is_unknown_type Used to distinguish between
* `PrimType(DataType::Handle())` and
* `PointerType(PrimType(DataType::Void()))`. If true, resolve dtype
* of `Void()` as `PrimType`, and if false resolve dtype of `Void()`
* as a `PointerType`.
*
* \return The pointer.
*/
inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), //
String storage_scope = "global", //
bool is_size_var = false) {
inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
String storage_scope = "global", bool is_size_var = false,
bool is_unknown_type = false) {
Type type_annotation{nullptr};
if (dtype.is_void() && storage_scope == "global") {
if (is_unknown_type && storage_scope == "global") {
type_annotation = PrimType(runtime::DataType::Handle());
} else {
type_annotation = PointerType(PrimType(dtype), storage_scope);
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,9 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member


def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: bool = False) -> Var:
def handle(
dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False
) -> Var:
"""Create a TIR var that represents a pointer.
Parameters
Expand All @@ -1460,7 +1462,15 @@ def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: b
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
return _ffi_api.Handle(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member
is_unknown_type = dtype is None
if dtype is None:
dtype = "void"
return _ffi_api.Handle( # type: ignore[attr-defined] # pylint: disable=no-member
dtype,
storage_scope,
is_size_var,
is_unknown_type,
)


def void(expr: Optional[PrimExpr] = None, *, is_size_var: bool = False) -> PrimExpr:
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3332,6 +3332,25 @@ def func():
return func


def test_void_ptr_vs_handle():
"""Distinguish between void* and handle
In the future, perhaps these should be de-duplicated by forbidding
one of the two C++ representations.
"""
# Generates PointerType(PrimType(DataType::Void()))
@T.prim_func
def void_ptr(out_ret_value: T.handle("void")):
T.evaluate(out_ret_value)

# Generates PrimType(DataType::Handle())
@T.prim_func
def handle(out_ret_value: T.handle):
T.evaluate(out_ret_value)

assert not tvm.ir.structural_equal(void_ptr, handle)


def void_ptr():
@T.prim_func
def func(out_ret_value: T.handle("void")):
Expand Down

0 comments on commit ad97955

Please sign in to comment.