Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Distinguish between void* and handle #14488

Merged
merged 2 commits into from
Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,16 +428,26 @@ void Evaluate(PrimExpr value);

/*!
* \brief Create a TIR var that represents a pointer
*
* \param dtype The data type of the pointer.
*
* \param storage_scope The storage scope of the pointer.
*
* \param is_size_var Whether the pointer is a size var.
*
* \param is_unknown_type Used to distinguish between
* `PrimType(DataType::Handle())` and
* `PointerType(PrimType(DataType::Void()))`. If true, resolve dtype
* of `Void()` as `PrimType`, and if false resolve dtype of `Void()`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Void is defined as DataType(kHandle, 0, 0), while a legitimate handle has non-zero bits and lanes. Do we need an additional parameter here instead of passing the right thing in dtype?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so, because the dtype is providing the type of the pointed-to object, not the type of the pointer itself. While we could special-case passing of a DataType(kHandle, 0, 0) to produce PrimType(DataType::Handle()), that would give unexpected results if somebody tries to write a pointer-to-pointer as T.handle("handle") and accidentally hits the special-case.

* as a `PointerType`.
*
* \return The pointer.
*/
inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), //
String storage_scope = "global", //
bool is_size_var = false) {
inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
String storage_scope = "global", bool is_size_var = false,
bool is_unknown_type = false) {
Type type_annotation{nullptr};
if (dtype.is_void() && storage_scope == "global") {
if (is_unknown_type && storage_scope == "global") {
type_annotation = PrimType(runtime::DataType::Handle());
} else {
type_annotation = PointerType(PrimType(dtype), storage_scope);
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,9 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member


def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: bool = False) -> Var:
def handle(
dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False
) -> Var:
"""Create a TIR var that represents a pointer.
Parameters
Expand All @@ -1460,7 +1462,15 @@ def handle(dtype: str = "void", storage_scope: str = "global", *, is_size_var: b
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
return _ffi_api.Handle(dtype, storage_scope, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member
is_unknown_type = dtype is None
if dtype is None:
dtype = "void"
return _ffi_api.Handle( # type: ignore[attr-defined] # pylint: disable=no-member
dtype,
storage_scope,
is_size_var,
is_unknown_type,
)


def void(expr: Optional[PrimExpr] = None, *, is_size_var: bool = False) -> PrimExpr:
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3332,6 +3332,25 @@ def func():
return func


def test_void_ptr_vs_handle():
"""Distinguish between void* and handle

In the future, perhaps these should be de-duplicated by forbidding
one of the two C++ representations.
"""
# Generates PointerType(PrimType(DataType::Void()))
@T.prim_func
def void_ptr(out_ret_value: T.handle("void")):
T.evaluate(out_ret_value)

# Generates PrimType(DataType::Handle())
@T.prim_func
def handle(out_ret_value: T.handle):
T.evaluate(out_ret_value)

assert not tvm.ir.structural_equal(void_ptr, handle)


def void_ptr():
@T.prim_func
def func(out_ret_value: T.handle("void")):
Expand Down