Skip to content

Commit

Permalink
[TVMScript][UX] Introduce decorator for deprecation
Browse files Browse the repository at this point in the history
This PR introduces a decorator `tvm.ir.base.deprecated`, which emits a
deprecation warning if an outdated API is used, but preserves backward
compatibility by still allowing the API to be used.

For example, currently the preferred way of TIR buffer declaration in
function signature is:

```python
def example(
  A: T.Buffer(...),  # legacy behavior is `T.Buffer[...]`
): ...
```

With this decorator, if a user writes `T.Buffer[...]`, the parser will
still function properly, but emits a warning that guides the user to
adopt `T.Buffer(...)` if possible.

While there is no breaking change at all in this PR, we believe this
is useful to help users upgrade before any breaking change eventually
takes place.
  • Loading branch information
junrushao committed Feb 9, 2023
1 parent 45a92df commit 9b1846d
Show file tree
Hide file tree
Showing 176 changed files with 3,563 additions and 3,524 deletions.
2 changes: 1 addition & 1 deletion apps/pt_tvmdsoop/tests/test_as_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
@tvm.script.ir_module
class ModuleGPU:
@T.prim_func
def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:
def main(A: T.Buffer(8, "float32"), B: T.Buffer(8, "float32")) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i_0 in T.thread_binding(2, thread="blockIdx.x"):
for i_2 in T.thread_binding(2, thread="threadIdx.x"):
Expand Down
8 changes: 4 additions & 4 deletions apps/pt_tvmdsoop/tests/test_boolean_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def test_tensor_boolean_operation():
@as_torch
@T.prim_func
def negate_tvmscript(
X: T.Buffer[(8, 8), "bool"],
Y: T.Buffer[(8, 8), "float32"],
Z: T.Buffer[(8, 8), "bool"],
U: T.Buffer[(8, 8), "float32"],
X: T.Buffer((8, 8), "bool"),
Y: T.Buffer((8, 8), "float32"),
Z: T.Buffer((8, 8), "bool"),
U: T.Buffer((8, 8), "float32"),
) -> None:
for i, j in T.grid(8, 8):
with T.block():
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ class AssignDocNode : public StmtDocNode {
/*!
* \brief The right hand side of the assignment.
*
* If null, this doc represents declaration, e.g. `A: T.Buffer[(1,2)]`
* If null, this doc represents declaration, e.g. `A: T.Buffer((1,2))`
* */
Optional<ExprDoc> rhs;
/*! \brief The type annotation of this assignment. */
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*
* \code{.py}
* @T.prim_func
* def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
* def before_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
* for i in T.serial(0, 16,
* annotations={"software_pipeline_stage": [0, 1],
Expand All @@ -601,7 +601,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*
* \code{.py}
* @T.prim_func
* def after_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
* def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
* for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
* with T.block():
* T.reads([A[tx, 0:16]])
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,34 @@ def structural_hash(node, map_free_vars=False):
structrual_equal
"""
return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member


def deprecated(
method_name: str,
new_method_name: str,
):
"""A decorator to indicate that a method is deprecated
Parameters
----------
method_name : str
The name of the method to deprecate
new_method_name : str
The name of the new method to use instead
"""
import functools # pylint: disable=import-outside-toplevel
import warnings # pylint: disable=import-outside-toplevel

def _deprecate(func):
@functools.wraps(func)
def _wrapper(*args, **kwargs):
warnings.warn(
f"{method_name} is deprecated, use {new_method_name} instead",
DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)

return _wrapper

return _deprecate
40 changes: 20 additions & 20 deletions python/tvm/micro/contrib/stm32/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Module container of STM32 code generator."""

from .emitter import CodeEmitter, get_input_tensor_name, get_output_tensor_name
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Module container of STM32 code generator."""

from .emitter import CodeEmitter, get_input_tensor_name, get_output_tensor_name
6 changes: 6 additions & 0 deletions python/tvm/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,36 @@
# under the License.
# pylint: disable=invalid-name
"""The legacy TVM parser """
from .ir.base import deprecated

# pylint: disable=import-outside-toplevel


@deprecated("tvm.parser.parse", "tvm.relay.parse")
def parse(*args, **kwargs):
"""Deprecated, use `tvm.relay.parse` instead"""
from tvm.relay import parse as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.parse_expr", "tvm.relay.parse_expr")
def parse_expr(*args, **kwargs):
"""Deprecated, use `tvm.relay.parse_expr` instead"""
from tvm.relay import parse_expr as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.fromtext", "tvm.relay.fromtext")
def fromtext(*args, **kwargs):
"""Deprecated, use `tvm.relay.fromtext` instead"""
from tvm.relay import fromtext as _impl

return _impl(*args, **kwargs)


@deprecated("tvm.parser.SpanCheck", "tvm.relay.SpanCheck")
def SpanCheck(*args, **kwargs):
"""Deprecated, use `tvm.relay.SpanCheck` instead"""
from tvm.relay import SpanCheck as _impl
Expand Down
1 change: 0 additions & 1 deletion python/tvm/script/parser/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script Parser utils"""

import inspect
from types import FrameType
from typing import Any, Callable, Dict, List
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
from typing import Callable, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer_decl, ptr
Expand Down Expand Up @@ -49,7 +50,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:

class BufferProxy:
"""Buffer proxy class for constructing tir buffer.
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer[].
Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer().
"""

def __call__(
Expand Down Expand Up @@ -78,6 +79,7 @@ def __call__(
axis_separators=axis_separators,
)

@deprecated("T.Buffer(...)", "T.Buffer(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
Expand All @@ -88,14 +90,15 @@ def __getitem__(self, keys) -> Buffer:

class PtrProxy:
"""Ptr proxy class for constructing tir pointer.
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr[].
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
"""

def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore

@deprecated("T.Ptr(...)", "T.Ptr(...)")
def __getitem__(self, keys):
if not isinstance(keys, tuple):
return self(keys)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,13 +1932,13 @@ class object that inherits from `Exception`.
class TestRemoveIf(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.Simplify()
def before(A: T.Buffer[1, "int32"]):
def before(A: T.Buffer(1, "int32")):
if True:
A[0] = 42
else:
A[0] = 5
def expected(A: T.Buffer[1, "int32"]):
def expected(A: T.Buffer(1, "int32")):
A[0] = 42
"""
Expand Down
Loading

0 comments on commit 9b1846d

Please sign in to comment.