Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript][UX] Introduce decorator for deprecation #13941

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/pt_tvmdsoop/tests/test_as_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
@tvm.script.ir_module
class ModuleGPU:
@T.prim_func
def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
def main(A: T.Buffer(8, "float32"), B: T.Buffer(8, "float32")) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i_0 in T.thread_binding(2, thread="blockIdx.x"):
for i_2 in T.thread_binding(2, thread="threadIdx.x"):
Expand Down
8 changes: 4 additions & 4 deletions apps/pt_tvmdsoop/tests/test_boolean_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def test_tensor_boolean_operation():
@as_torch
@T.prim_func
def negate_tvmscript(
X: T.Buffer[(8, 8), "bool"],
Y: T.Buffer[(8, 8), "float32"],
Z: T.Buffer[(8, 8), "bool"],
U: T.Buffer[(8, 8), "float32"],
X: T.Buffer((8, 8), "bool"),
Y: T.Buffer((8, 8), "float32"),
Z: T.Buffer((8, 8), "bool"),
U: T.Buffer((8, 8), "float32"),
) -> None:
for i, j in T.grid(8, 8):
with T.block():
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ class AssignDocNode : public StmtDocNode {
/*!
* \brief The right hand side of the assignment.
*
* If null, this doc represents declaration, e.g. `A: T.Buffer[(1,2)]`
* If null, this doc represents declaration, e.g. `A: T.Buffer((1,2))`
* */
Optional<ExprDoc> rhs;
/*! \brief The type annotation of this assignment. */
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*
* \code{.py}
* @T.prim_func
* def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
* def before_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
* for i in T.serial(0, 16,
* annotations={"software_pipeline_stage": [0, 1],
Expand All @@ -601,7 +601,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*
* \code{.py}
* @T.prim_func
* def after_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
* def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
* with T.block():
* T.reads([A[tx, 0:16]])
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,34 @@ def structural_hash(node, map_free_vars=False):
structrual_equal
"""
return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member


def deprecated(
method_name: str,
new_method_name: str,
):
"""A decorator to indicate that a method is deprecated

Parameters
----------
method_name : str
The name of the method to deprecate
new_method_name : str
The name of the new method to use instead
"""
import functools # pylint: disable=import-outside-toplevel
import warnings # pylint: disable=import-outside-toplevel

def _deprecate(func):
@functools.wraps(func)
def _wrapper(*args, **kwargs):
warnings.warn(
f"{method_name} is deprecated, use {new_method_name} instead",
DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)

return _wrapper

return _deprecate
6 changes: 6 additions & 0 deletions python/tvm/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,36 @@
# under the License.
# pylint: disable=invalid-name
"""The legacy TVM parser """
from .ir.base import deprecated

# pylint: disable=import-outside-toplevel


@deprecated("tvm.parser.parse", "tvm.relay.parse")
def parse(*args, **kwargs):
"""Deprecated, use `tvm.relay.parse` instead"""
from tvm.relay import parse as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.parse_expr", "tvm.relay.parse_expr")
def parse_expr(*args, **kwargs):
"""Deprecated, use `tvm.relay.parse_expr` instead"""
from tvm.relay import parse_expr as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.fromtext", "tvm.relay.fromtext")
def fromtext(*args, **kwargs):
"""Deprecated, use `tvm.relay.fromtext` instead"""
from tvm.relay import fromtext as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.SpanCheck", "tvm.relay.SpanCheck")
def SpanCheck(*args, **kwargs):
"""Deprecated, use `tvm.relay.SpanCheck` instead"""
from tvm.relay import SpanCheck as _impl
Expand Down
22 changes: 21 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np # type: ignore

from tvm.ir import Range, Type
from tvm.ir.base import deprecated
from tvm.runtime import convert, ndarray
from tvm.target import Target

Expand Down Expand Up @@ -1427,6 +1428,26 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var:
return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member


@deprecated("T.buffer_var", "T.Ptr")
def buffer_var(dtype: str, storage_scope: str = "global") -> Var:
"""The pointer declaration function.

Parameters
----------
dtype : str
The data type of the pointer.

storage_scope : str
The storage scope of the pointer.

Returns
-------
res : Var
The pointer.
"""
return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member


def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin
"""Compute the minimum value of two expressions.

Expand Down Expand Up @@ -1703,7 +1724,6 @@ def wrapped(*args, **kwargs):

broadcast = Broadcast
ramp = Ramp
buffer_var = ptr
fabs = abs
tvm_call_packed = call_packed
tvm_call_cpacked = call_cpacked
Expand Down
1 change: 0 additions & 1 deletion python/tvm/script/parser/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script Parser utils"""

import inspect
from types import FrameType
from typing import Any, Callable, Dict, List
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
from typing import Callable, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer_decl, ptr
Expand Down Expand Up @@ -49,7 +50,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:

class BufferProxy:
"""Buffer proxy class for constructing tir buffer.
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer[].
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer().
"""

def __call__(
Expand Down Expand Up @@ -78,6 +79,7 @@ def __call__(
axis_separators=axis_separators,
)

@deprecated("T.Buffer(...)", "T.Buffer(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
Expand All @@ -88,14 +90,15 @@ def __getitem__(self, keys) -> Buffer:

class PtrProxy:
"""Ptr proxy class for constructing tir pointer.
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr[].
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
"""

def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore

@deprecated("T.Ptr(...)", "T.Ptr(...)")
def __getitem__(self, keys):
if not isinstance(keys, tuple):
return self(keys)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,13 +1932,13 @@ class object that inherits from `Exception`.
class TestRemoveIf(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.Simplify()

def before(A: T.Buffer[1, "int32"]):
def before(A: T.Buffer(1, "int32")):
if True:
A[0] = 42
else:
A[0] = 5

def expected(A: T.Buffer[1, "int32"]):
def expected(A: T.Buffer(1, "int32")):
A[0] = 42

"""
Expand Down
Loading