Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Unify T.handle and T.Ptr #13969

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,12 @@ void Prefetch(Buffer buffer, Array<Range> bounds);
void Evaluate(PrimExpr value);

/*!
* \brief The pointer declaration function.
* \brief Create a TIR var that represents a pointer
* \param dtype The data type of the pointer.
* \param storage_scope The storage scope of the pointer.
* \return The pointer.
*/
PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String storage_scope = "global");

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
Expand Down Expand Up @@ -455,7 +455,6 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());

#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,20 +1358,23 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint: disable=no-member


def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
"""Construct a new tir.Var with type handle or cast expression to type handle.
def handle(dtype: str = "void", storage_scope: str = "global") -> Var:
"""Create a TIR var that represents a pointer.

Parameters
----------
expr: PrimExpr
The expression to be cast.
dtype: str
The data type of the pointer.

storage_scope: str
The storage scope of the pointer.

Returns
-------
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
return _ffi_api.Handle(expr) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Handle(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member


def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __call__(
axis_separators=axis_separators,
)

@deprecated("T.Buffer(...)", "T.Buffer(...)")
@deprecated("T.Buffer[...]", "T.Buffer(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
Expand All @@ -93,12 +93,13 @@ class PtrProxy:
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
"""

@deprecated("T.Ptr(...)", "T.handle(...)")
def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore

@deprecated("T.Ptr(...)", "T.Ptr(...)")
@deprecated("T.Ptr[...]", "T.handle(...)")
def __getitem__(self, keys):
if not isinstance(keys, tuple):
return self(keys)
Expand Down
10 changes: 10 additions & 0 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,16 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
}

Var Handle(runtime::DataType dtype, String storage_scope) {
Type type_annotation{nullptr};
if (dtype.is_void() && storage_scope == "global") {
type_annotation = PrimType(runtime::DataType::Handle());
} else {
type_annotation = PointerType(PrimType(dtype), storage_scope);
}
return tvm::tir::Var("", type_annotation);
}

using tvm::script::ir_builder::details::Namer;

TVM_STATIC_IR_FUNCTOR(Namer, vtable)
Expand Down
6 changes: 3 additions & 3 deletions src/script/printer/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
element_type = d->AsDoc<ExprDoc>(ty->element_type, ty_p->Attr("element_type"));
}
if (ty->storage_scope == "") {
return TIR(d, "Ptr")->Call({element_type});
return TIR(d, "handle")->Call({element_type});
} else {
return TIR(d, "Ptr")->Call(
{element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
return TIR(d, "handle")
->Call({element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_create_executor_metadata_single_func():
class Module:
@T.prim_func
def __tvm_main__(
a: T.handle, output: T.handle, workspace: T.Ptr(T.uint8), constants: T.Ptr(T.uint8)
a: T.handle, output: T.handle, workspace: T.handle("uint8"), constants: T.handle("uint8")
) -> None:
# function attr dict
T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind": "llvm", "tag": "", "keys": ["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": ["test_device"]})
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/aot/test_pass_aot_lower_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
def func(a: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "test_mod___tvm_main__", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": []})
tmp_read = T.Ptr("uint8", "")
tmp_read = T.handle("uint8", "")
# buffer definition
tmp_read_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_read)
a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16)
output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16)
# body
tmp_write: T.Ptr(T.uint8) = output_buffer.data
tmp_write: T.handle("uint8") = output_buffer.data
tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write)
for i in T.serial(140):
tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def test_buffer_conditional_lowering():
"""

@T.prim_func
def before(A: T.Ptr("float32")):
def before(A: T.handle("float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i in range(1):
A_1 = T.Buffer((1,), data=A)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def main():
T.func_attr({"from_legacy_te_schedule": True})

# If a pointer defined using a LetStmt,
A_data: T.Ptr("int32") = T.call_extern("dummy_extern_function", dtype="handle")
A_data: T.handle("int32") = T.call_extern("dummy_extern_function", dtype="handle")

# and a buffer is backed by that pointer,
A = T.decl_buffer([1], dtype="float32", data=A_data)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_tir_transform_storage_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,12 +689,12 @@ class TestLetBufferRewrite(BaseCompare):
"""

def before() -> None:
A_data: T.Ptr("int32") = T.call_extern("dummy_func", dtype="handle")
A_data: T.handle("int32") = T.call_extern("dummy_func", dtype="handle")
A = T.Buffer([8], "int32", data=A_data)
A[0:8] = T.broadcast(42, 8)

def expected() -> None:
A_data: T.Ptr("int32x8") = T.call_extern("dummy_func", dtype="handle")
A_data: T.handle("int32x8") = T.call_extern("dummy_func", dtype="handle")
A = T.Buffer([1], "int32x8", data=A_data)
A[0] = T.broadcast(42, 8)

Expand Down
Loading