Skip to content

Commit

Permalink
[TVMScript] Distinguish between void* and handle
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Apr 4, 2023
1 parent b724c87 commit cbd086b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
4 changes: 2 additions & 2 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,9 @@ void Evaluate(PrimExpr value);
*/
inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), //
String storage_scope = "global", //
bool is_size_var = false) {
bool is_size_var = false, bool is_unknown_type = false) {
Type type_annotation{nullptr};
if (dtype.is_void() && storage_scope == "global") {
if (is_unknown_type && storage_scope == "global") {
type_annotation = PrimType(runtime::DataType::Handle());
} else {
type_annotation = PointerType(PrimType(dtype), storage_scope);
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,9 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member


def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: bool = False) -> Var:
def handle(
dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False
) -> Var:
"""Create a TIR var that represents a pointer.
Parameters
Expand All @@ -1460,7 +1462,15 @@ def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: b
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
return _ffi_api.Handle(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member
is_unknown_type = dtype is None
if dtype is None:
dtype = "void"
return _ffi_api.Handle( # type: ignore[attr-defined] # pylint: disable=no-member
dtype,
storage_scope,
is_size_var,
is_unknown_type,
)


def void(expr: Optional[PrimExpr] = None, *, is_size_var: bool = False) -> PrimExpr:
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3332,6 +3332,25 @@ def func():
return func


def test_void_ptr_vs_handle():
"""Distinguish between void* and handle
In the future, perhaps these should be de-duplicated by forbidding
one of the two C++ representations.
"""
# Generates PointerType(PrimType(DataType::Void()))
@T.prim_func
def void_ptr(out_ret_value: T.handle("void")):
T.evaluate(out_ret_value)

# Generates PrimType(DataType::Handle())
@T.prim_func
def handle(out_ret_value: T.handle):
T.evaluate(out_ret_value)

assert not tvm.ir.structural_equal(void_ptr, handle)


def void_ptr():
@T.prim_func
def func(out_ret_value: T.handle("void")):
Expand Down

0 comments on commit cbd086b

Please sign in to comment.