Skip to content

Commit

Permalink
[IRBuilder] Misc fix-ups for the python binding (apache#39)
Browse files Browse the repository at this point in the history
* [IRBuilder] Misc fix-ups for the python binding

* addon
  • Loading branch information
junrushao committed Jun 10, 2022
1 parent 9bc4557 commit f15998f
Show file tree
Hide file tree
Showing 18 changed files with 108 additions and 148 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/builder/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
"""FFI APIs for tvm.script.builder"""
import tvm._ffi

tvm._ffi._init_api("script.builder", __name__)
tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access
31 changes: 17 additions & 14 deletions python/tvm/script/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,47 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script IR Builder"""
from typing import List
from tvm._ffi import register_object as _register_object
from .frame import Frame
from typing import List, TypeVar

from tvm._ffi import register_object as _register_object
from tvm.runtime import Object

from . import _ffi_api

from typing import TypeVar
from .frame import Frame


@_register_object("script.builder.Builder")
class Builder(Object):
def __init__(self) -> None:
self.__init_handle_by_constructor__(_ffi_api.Builder)
self.__init_handle_by_constructor__(
_ffi_api.Builder # pylint: disable=no-member # type: ignore
)

def __enter__(self) -> "Builder":
_ffi_api.BuilderEnter(self)
_ffi_api.BuilderEnter(self) # pylint: disable=no-member # type: ignore
return self

def __exit__(self, ptype, value, trace) -> None:
_ffi_api.BuilderExit(self)
def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.BuilderExit(self) # pylint: disable=no-member # type: ignore

@staticmethod
def current(self) -> "Builder":
return _ffi_api.BuilderCurrent(self)
def current() -> "Builder":
return _ffi_api.BuilderCurrent() # pylint: disable=no-member # type: ignore

def get(self) -> Frame:
return _ffi_api.BuilderGet(self)
return _ffi_api.BuilderGet(self) # pylint: disable=no-member # type: ignore


DefType = TypeVar("DefType", bound=Object)


def def_(name: str, var: DefType) -> DefType:
return _ffi_api.Def(name, var)
return _ffi_api.Def(name, var) # pylint: disable=no-member # type: ignore


def def_many(names: List[str], vars: List[DefType]) -> List[DefType]:
def def_many(
names: List[str],
vars: List[DefType], # pylint: disable=redefine-builtin
) -> List[DefType]:
assert len(names) == len(vars)
return [def_(name, var) for name, var in zip(names, vars)]
11 changes: 6 additions & 5 deletions python/tvm/script/builder/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
"""TVM Script Frames"""
from tvm._ffi import register_object as _register_object

from tvm.runtime import Object

from . import _ffi_api
Expand All @@ -25,14 +24,16 @@
@_register_object("script.builder.Frame")
class Frame(Object):
def __enter__(self) -> "Frame":
_ffi_api.FrameEnter(self)
_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore
return self

def __exit__(self, ptype, value, trace) -> None:
_ffi_api.FrameExit(self)
def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.FrameExit(self) # pylint: disable=no-member # type: ignore


@_register_object("script.builder.IRModuleFrame")
class IRModuleFrame(Frame):
def __init__(self) -> None:
self.__init_handle_by_constructor__(_ffi_api.IRModuleFrame)
self.__init_handle_by_constructor__(
_ffi_api.IRModuleFrame # pylint: disable=no-member # type: ignore
)
14 changes: 7 additions & 7 deletions python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
# pylint: disable=unused-import
"""Namespace for the TVMScript TIR Builder API."""

from . import axis
from .base import TIRFrame
from .block_frame import block
from .for_frame import (
ForFrame,
serial,
grid,
parallel,
vectorized,
unroll,
serial,
thread_binding,
grid,
unroll,
vectorized,
)
from .prim_func_frame import prim_func, arg
from .block_frame import block
from .prim_func_frame import arg, prim_func
from .var import Buffer
from . import axis
6 changes: 2 additions & 4 deletions python/tvm/script/builder/tir/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.script.builder"""
"""FFI APIs for tvm.script.builder.tir"""
import tvm._ffi

from .. import _ffi_api as _base_ffi_api

tvm._ffi._init_api("script.builder.tir", __name__)
tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access
9 changes: 5 additions & 4 deletions python/tvm/script/builder/tir/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@
# under the License.
"""TVM Script TIR Axis"""

from . import _ffi_api
from tvm.ir import Range
from tvm.tir import IterVar

from . import _ffi_api


def spatial(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
dom = Range(0, dom)
return _ffi_api.AxisSpatial(dom, binding, dtype)
return _ffi_api.AxisSpatial(dom, binding, dtype) # pylint: disable=no-member # type: ignore


def reduce(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
dom = Range(0, dom)
return _ffi_api.AxisReduce(dom, binding, dtype)
return _ffi_api.AxisReduce(dom, binding, dtype) # pylint: disable=no-member # type: ignore


def remap(kinds, bindings, dtype="int32") -> IterVar:
return _ffi_api.AxisRemap(kinds, bindings, dtype)
return _ffi_api.AxisRemap(kinds, bindings, dtype) # pylint: disable=no-member # type: ignore
3 changes: 1 addition & 2 deletions python/tvm/script/builder/tir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
"""TVM Script TIR Frame"""
from tvm._ffi import register_object as _register_object

from . import _ffi_api
from ..frame import Frame


@_register_object("script.builder.tir.TIRFrame")
class TIRFrame(Frame):
pass
...
7 changes: 3 additions & 4 deletions python/tvm/script/builder/tir/block_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@
# under the License.
"""TVM Script TIR Block Frame"""
from tvm._ffi import register_object as _register_object
from .base import TIRFrame


from . import _ffi_api
from .base import TIRFrame


@_register_object("script.builder.tir.BlockFrame")
class BlockFrame(TIRFrame):
pass
...


def block(name) -> BlockFrame:
return _ffi_api.BlockFrame(name)
return _ffi_api.BlockFrame(name) # pylint: disable=no-member # type: ignore
32 changes: 17 additions & 15 deletions python/tvm/script/builder/tir/for_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,44 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR For Frame"""
from tvm._ffi import register_object as _register_object
from typing import List

from tvm._ffi import register_object as _register_object
from tvm.tir import Var

from . import _ffi_api
from ._ffi_api import _base_ffi_api
from .. import _ffi_api as _base_ffi_api
from .base import TIRFrame
from typing import List


@_register_object("script.builder.tir.ForFrame")
class ForFrame(TIRFrame):
def __enter__(self) -> List[Var]:
_base_ffi_api.FrameEnter(self)
_base_ffi_api.FrameEnter(self) # pylint: disable=no-member # type: ignore
return self.vars


def serial(min_val, extent, attrs) -> ForFrame:
return _ffi_api.Serial(min_val, extent, attrs)
def serial(start, stop, annotations) -> ForFrame:
return _ffi_api.Serial(start, stop, annotations) # pylint: disable=no-member # type: ignore


def parallel(min_val, extent, attrs) -> ForFrame:
return _ffi_api.Parallel(min_val, extent, attrs)
def parallel(start, stop, annotations) -> ForFrame:
return _ffi_api.Parallel(start, stop, annotations) # pylint: disable=no-member # type: ignore


def vectorized(min_val, extent, attrs) -> ForFrame:
return _ffi_api.Vectorized(min_val, extent, attrs)
def vectorized(start, stop, annotations) -> ForFrame:
return _ffi_api.Vectorized(start, stop, annotations) # pylint: disable=no-member # type: ignore


def unroll(min_val, extent, attrs) -> ForFrame:
return _ffi_api.Unroll(min_val, extent, attrs)
def unroll(start, stop, annotations) -> ForFrame:
return _ffi_api.Unroll(start, stop, annotations) # pylint: disable=no-member # type: ignore


def thread_binding(min_val, extent, attrs) -> ForFrame:
return _ffi_api.ThreadBinding(min_val, extent, attrs)
def thread_binding(start, stop, thread, annotations) -> ForFrame:
return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore
start, stop, thread, annotations
)


def grid(*extents) -> ForFrame:
return _ffi_api.Grid(extents)
return _ffi_api.Grid(extents) # pylint: disable=no-member # type: ignore
16 changes: 7 additions & 9 deletions python/tvm/script/builder/tir/prim_func_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,24 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Prim Func Frame"""
from tvm._ffi import register_object as _register_object
from typing import Union

from tvm.tir.expr import Var
from tvm._ffi import register_object as _register_object
from tvm.tir.buffer import Buffer

from tvm.tir.expr import Var

from . import _ffi_api
from .base import TIRFrame

from typing import Union


@_register_object("script.builder.tir.PrimFuncFrame")
class PrimFuncFrame(TIRFrame):
pass
...


def prim_func(name) -> PrimFuncFrame:
return _ffi_api.PrimFuncFrame(name)
return _ffi_api.PrimFuncFrame(name) # pylint: disable=no-member # type: ignore


def arg(name, arg) -> Union[Var, Buffer]:
return _ffi_api.Arg(name, arg)
def arg(name, obj) -> Union[Var, Buffer]:
return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore
13 changes: 8 additions & 5 deletions python/tvm/script/builder/tir/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Buffer"""
from tvm._ffi import register_object as _register_object

from tvm.tir.buffer import Buffer
from tvm import tir

from . import _ffi_api


def Buffer(shape, dtype, name="buffer", storage_scope="") -> Buffer:
return _ffi_api.Buffer(shape, dtype, name, storage_scope)
def Buffer( # pylint: disable=invalid-name
shape,
dtype,
name="buffer",
storage_scope="",
) -> tir.Buffer:
return _ffi_api.Buffer(shape, dtype, name, storage_scope) # pylint: disable=no-member # type: ignore
9 changes: 1 addition & 8 deletions src/script/builder/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,9 @@ namespace tvm {
namespace script {
namespace builder {

void FrameNode::EnterWithScope() {
LOG(INFO) << "EnterWithScope: " << this->GetTypeKey();
// Push to the current builder
Builder::Current()->frames.push_back(GetRef<Frame>(this));
}
void FrameNode::EnterWithScope() { Builder::Current()->frames.push_back(GetRef<Frame>(this)); }

void FrameNode::ExitWithScope() {
LOG(INFO) << "ExitWithScope: " << this->GetTypeKey();
for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
(*it)();
}
Expand Down Expand Up @@ -60,9 +55,7 @@ void IRModuleFrameNode::ExitWithScope() {

TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);

TVM_REGISTER_GLOBAL("script.builder.FrameEnter").set_body_method<Frame>(&FrameNode::EnterWithScope);

TVM_REGISTER_GLOBAL("script.builder.FrameExit").set_body_method<Frame>(&FrameNode::ExitWithScope);

} // namespace builder
Expand Down
30 changes: 0 additions & 30 deletions src/script/builder/tir/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,6 @@ namespace tir {

TVM_REGISTER_NODE_TYPE(TIRFrameNode);

void TestPOC() {
namespace T = tvm::script::builder::tir;
using namespace ::tvm::tir;

With<Builder> builder;
{
With<PrimFuncFrame> _{T::PrimFunc_("main")};
Buffer A = T::Arg("A", T::Buffer_({128, 128, 128}, DataType::Float(32)));
Buffer B = T::Arg("B", T::Buffer_({128, 128, 128}, DataType::Float(32)));
{
With<ForFrame> _{T::Grid({128, 128, 128})};
Var i = Def("i", _()->vars[0]);
Var j = Def("j", _()->vars[1]);
Var k = Def("k", _()->vars[2]);
{
With<BlockFrame> _{T::Block_("block")};
IterVar vi = Def("vi", T::axis::Spatial(Range(0, 128), i));
IterVar vj = Def("vj", T::axis::Spatial(Range(0, 128), j));
IterVar vk = Def("vk", T::axis::Reduce(Range(0, 128), k));
}
LOG(INFO) << "ForFrame:\n" << _()->stmts;
}
LOG(INFO) << "PrimFuncFrame:\n" << _()->stmts;
}
PrimFunc func = builder()->Get<PrimFunc>();
LOG(INFO) << "func:\n" << AsTVMScript(func);
}

TVM_REGISTER_GLOBAL("test_poc").set_body_typed(TestPOC);

} // namespace tir
} // namespace builder
} // namespace script
Expand Down
4 changes: 0 additions & 4 deletions src/script/builder/tir/block_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,9 @@ Array<tvm::tir::IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType
} // namespace axis

TVM_REGISTER_NODE_TYPE(BlockFrameNode);

TVM_REGISTER_GLOBAL("script.builder.tir.BlockFrame").set_body_typed(Block_);

TVM_REGISTER_GLOBAL("script.builder.tir.AxisSpatial").set_body_typed(axis::Spatial);

TVM_REGISTER_GLOBAL("script.builder.tir.AxisReduce").set_body_typed(axis::Reduce);

TVM_REGISTER_GLOBAL("script.builder.tir.AxisRemap").set_body_typed(axis::Remap);

} // namespace tir
Expand Down
Loading

0 comments on commit f15998f

Please sign in to comment.