Skip to content

Commit

Permalink
[TVMScript][UX] Introduce decorator for deprecation (#13941)
Browse files Browse the repository at this point in the history
This PR introduces a decorator `tvm.ir.base.deprecated`, which emits a
deprecation warning if an outdated API is used, but preserves backward
compatibility by still allowing the API to be used.

For example, currently the preferred way of TIR buffer declaration in
function signature is:

```python
def example(
  A: T.Buffer(...),  # legacy behavior is `T.Buffer[...]`
): ...
```

With this decorator, if a user writes `T.Buffer[...]`, the parser will
still function properly, but emits a warning that guides the user to
adopt `T.Buffer(...)` if possible.

While there is no breaking change at all in this PR, we believe this
is useful to help users upgrade before any breaking change eventually
takes place.
  • Loading branch information
junrushao authored Feb 10, 2023
1 parent 6f0e2ed commit 256bad7
Show file tree
Hide file tree
Showing 164 changed files with 1,661 additions and 1,607 deletions.
2 changes: 1 addition & 1 deletion apps/pt_tvmdsoop/tests/test_as_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
@tvm.script.ir_module
class ModuleGPU:
@T.prim_func
def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
def main(A: T.Buffer(8, "float32"), B: T.Buffer(8, "float32")) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i_0 in T.thread_binding(2, thread="blockIdx.x"):
for i_2 in T.thread_binding(2, thread="threadIdx.x"):
Expand Down
8 changes: 4 additions & 4 deletions apps/pt_tvmdsoop/tests/test_boolean_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def test_tensor_boolean_operation():
@as_torch
@T.prim_func
def negate_tvmscript(
X: T.Buffer[(8, 8), "bool"],
Y: T.Buffer[(8, 8), "float32"],
Z: T.Buffer[(8, 8), "bool"],
U: T.Buffer[(8, 8), "float32"],
X: T.Buffer((8, 8), "bool"),
Y: T.Buffer((8, 8), "float32"),
Z: T.Buffer((8, 8), "bool"),
U: T.Buffer((8, 8), "float32"),
) -> None:
for i, j in T.grid(8, 8):
with T.block():
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ class AssignDocNode : public StmtDocNode {
/*!
* \brief The right hand side of the assignment.
*
* If null, this doc represents declaration, e.g. `A: T.Buffer[(1,2)]`
* If null, this doc represents declaration, e.g. `A: T.Buffer((1,2))`
* */
Optional<ExprDoc> rhs;
/*! \brief The type annotation of this assignment. */
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*
* \code{.py}
* @T.prim_func
* def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
* def before_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
* for i in T.serial(0, 16,
* annotations={"software_pipeline_stage": [0, 1],
Expand All @@ -601,7 +601,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*
* \code{.py}
* @T.prim_func
* def after_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
* def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
* with T.block():
* T.reads([A[tx, 0:16]])
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,34 @@ def structural_hash(node, map_free_vars=False):
structrual_equal
"""
return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member


def deprecated(
method_name: str,
new_method_name: str,
):
"""A decorator to indicate that a method is deprecated
Parameters
----------
method_name : str
The name of the method to deprecate
new_method_name : str
The name of the new method to use instead
"""
import functools # pylint: disable=import-outside-toplevel
import warnings # pylint: disable=import-outside-toplevel

def _deprecate(func):
@functools.wraps(func)
def _wrapper(*args, **kwargs):
warnings.warn(
f"{method_name} is deprecated, use {new_method_name} instead",
DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)

return _wrapper

return _deprecate
6 changes: 6 additions & 0 deletions python/tvm/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,36 @@
# under the License.
# pylint: disable=invalid-name
"""The legacy TVM parser """
from .ir.base import deprecated

# pylint: disable=import-outside-toplevel


@deprecated("tvm.parser.parse", "tvm.relay.parse")
def parse(*args, **kwargs):
"""Deprecated, use `tvm.relay.parse` instead"""
from tvm.relay import parse as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.parse_expr", "tvm.relay.parse_expr")
def parse_expr(*args, **kwargs):
"""Deprecated, use `tvm.relay.parse_expr` instead"""
from tvm.relay import parse_expr as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.fromtext", "tvm.relay.fromtext")
def fromtext(*args, **kwargs):
"""Deprecated, use `tvm.relay.fromtext` instead"""
from tvm.relay import fromtext as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.SpanCheck", "tvm.relay.SpanCheck")
def SpanCheck(*args, **kwargs):
"""Deprecated, use `tvm.relay.SpanCheck` instead"""
from tvm.relay import SpanCheck as _impl
Expand Down
22 changes: 21 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np # type: ignore

from tvm.ir import Range, Type
from tvm.ir.base import deprecated
from tvm.runtime import convert, ndarray
from tvm.target import Target

Expand Down Expand Up @@ -1427,6 +1428,26 @@ def ptr(dtype: str, storage_scope: str = "global") -> Var:
return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member


@deprecated("T.buffer_var", "T.Ptr")
def buffer_var(dtype: str, storage_scope: str = "global") -> Var:
"""The pointer declaration function.
Parameters
----------
dtype : str
The data type of the pointer.
storage_scope : str
The storage scope of the pointer.
Returns
-------
res : Var
The pointer.
"""
return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member


def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin
"""Compute the minimum value of two expressions.
Expand Down Expand Up @@ -1703,7 +1724,6 @@ def wrapped(*args, **kwargs):

broadcast = Broadcast
ramp = Ramp
buffer_var = ptr
fabs = abs
tvm_call_packed = call_packed
tvm_call_cpacked = call_cpacked
Expand Down
1 change: 0 additions & 1 deletion python/tvm/script/parser/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script Parser utils"""

import inspect
from types import FrameType
from typing import Any, Callable, Dict, List
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
from typing import Callable, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer_decl, ptr
Expand Down Expand Up @@ -49,7 +50,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:

class BufferProxy:
"""Buffer proxy class for constructing tir buffer.
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer[].
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer().
"""

def __call__(
Expand Down Expand Up @@ -78,6 +79,7 @@ def __call__(
axis_separators=axis_separators,
)

@deprecated("T.Buffer(...)", "T.Buffer(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
Expand All @@ -88,14 +90,15 @@ def __getitem__(self, keys) -> Buffer:

class PtrProxy:
"""Ptr proxy class for constructing tir pointer.
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr[].
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
"""

def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore

@deprecated("T.Ptr(...)", "T.Ptr(...)")
def __getitem__(self, keys):
if not isinstance(keys, tuple):
return self(keys)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,13 +1932,13 @@ class object that inherits from `Exception`.
class TestRemoveIf(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.Simplify()
def before(A: T.Buffer[1, "int32"]):
def before(A: T.Buffer(1, "int32")):
if True:
A[0] = 42
else:
A[0] = 5
def expected(A: T.Buffer[1, "int32"]):
def expected(A: T.Buffer(1, "int32")):
A[0] = 42
"""
Expand Down
Loading

0 comments on commit 256bad7

Please sign in to comment.