Skip to content

Commit

Permalink
[TVMScript] refactor (#6734)
Browse files Browse the repository at this point in the history
* [TVMScript] refactor

* [TVMScript] pylint

* [TVMScript] pylint
  • Loading branch information
spectrometerHBH authored Oct 22, 2020
1 parent f65e320 commit 129333b
Show file tree
Hide file tree
Showing 10 changed files with 829 additions and 909 deletions.
3 changes: 1 addition & 2 deletions python/tvm/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@
# under the License.
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""

from .utils import create_module, asscript, tir, module
from .parser import from_source
from .parser import from_source, create_module, asscript, tir, module
2 changes: 1 addition & 1 deletion python/tvm/script/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tvmscript"""
"""FFI APIs for tvm.script"""
import tvm._ffi

tvm._ffi._init_api("script", __name__)
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script Scope Emitter for TIR"""
"""TVM Script Context Maintainer for TIR"""

from tvm.te import schedule


class ScopeEmitter:
"""Maintain the nodes and symbols of scopes"""
class ContextMaintainer:
"""Maintain all the necessary context info"""

def __init__(self, parser):
self.node_stack = [[]] # AST nodes of scopes
self.symbols = [dict()] # Symbols of scopes
# scope context
self.node_stack = [] # AST nodes of scopes
self.symbols = [] # symbols of scopes
# function context
self.func_params = [] # parameter list of function
self.func_buffer_map = {} # buffer_map of function
self.func_dict_attr = {} # func_attr of function
self.func_var_env_dict = {} # map from var to env_name
# parser
self.parser = parser

def pop_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()

def new_scope(self):
""" Creating a new scope """
self.node_stack.append([])
def new_scope(self, nodes=None):
"""Creating a new scope"""
if nodes is None:
nodes = []
self.node_stack.append(list(reversed(nodes)))
self.symbols.append(dict())

def update_symbol(self, name, symbol):
Expand All @@ -60,3 +69,6 @@ def lookup_symbol(self, name):
if name in symbols:
return symbols[name]
return None

def report_error(self, message):
self.parser.report_error(message)
100 changes: 59 additions & 41 deletions python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,127 +14,127 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script Parser Intrinsic Functions
IRNodes (StmtNodes without body, PrimExprNodes and more) are called intrins
"""
# pylint: disable=redefined-builtin
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
import tvm.tir
from .registry import register_intrin
from .registry import register
from .utils import get_param_list


class Intrin:
def __init__(self, intrin, stmt=False):
self.intrin = intrin
self.stmt = stmt

def signature(self):
return "tir." + self.intrin.__name__, get_param_list(self.intrin)

def handle(self, arg_list):
return self.intrin(*arg_list)

@register_intrin()

@register
def bool(imm):
return tvm.tir.const(imm, "bool")


@register_intrin()
@register
def int8(imm):
return tvm.tir.const(imm, "int8")


@register_intrin()
@register
def int16(imm):
return tvm.tir.const(imm, "int16")


@register_intrin()
@register
def int32(imm):
return tvm.tir.const(imm, "int32")


@register_intrin()
@register
def int64(imm):
return tvm.tir.const(imm, "int64")


@register_intrin()
@register
def uint8(imm):
return tvm.tir.const(imm, "uint8")


@register_intrin()
@register
def uint16(imm):
return tvm.tir.const(imm, "uint16")


@register_intrin()
@register
def uint32(imm):
return tvm.tir.const(imm, "uint32")


@register_intrin()
@register
def uint64(imm):
return tvm.tir.const(imm, "uint64")


@register_intrin()
@register
def float8(imm):
return tvm.tir.const(imm, "float8")


@register_intrin()
@register
def float16(imm):
return tvm.tir.const(imm, "float16")


@register_intrin()
@register
def float32(imm):
return tvm.tir.const(imm, "float32")


@register_intrin()
@register
def float64(imm):
return tvm.tir.const(imm, "float64")


@register_intrin()
@register
def floordiv(x, y):
return tvm.tir.floordiv(x, y)


@register_intrin()
@register
def floormod(x, y):
return tvm.tir.floormod(x, y)


@register_intrin()
@register
def load(dtype, var, index, predicate=True):
return tvm.tir.Load(dtype, var, index, predicate)


@register_intrin()
@register
def cast(value, dtype):
return tvm.tir.Cast(dtype, value)


@register_intrin()
@register
def ramp(base, stride, lanes):
return tvm.tir.Ramp(base, stride, lanes)


@register_intrin()
@register
def broadcast(value, lanes):
return tvm.tir.Broadcast(value, lanes)


@register_intrin()
def evaluate(value):
return tvm.tir.Evaluate(value)


@register_intrin()
def store(var, index, value, predicate=True):
return tvm.tir.Store(var, value, index, predicate)


@register_intrin()
@register
def iter_var(var, dom, iter_type, thread_tag):
iter_type = getattr(tvm.tir.IterVar, iter_type)
return tvm.tir.IterVar(dom, var, iter_type, thread_tag)


@register_intrin()
@register
def max(a, b): # pylint: disable=redefined-builtin
return tvm.tir.Max(a, b)

Expand All @@ -148,21 +148,39 @@ def get_axis(begin, end, iter_type):
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type])


@register_intrin()
@register
def range(begin, end):
return get_axis(begin, end, "data_par")


@register_intrin()
@register
def reduce_axis(begin, end):
return get_axis(begin, end, "reduce")


@register_intrin()
@register
def scan_axis(begin, end):
return get_axis(begin, end, "scan")


@register_intrin()
@register
def opaque_axis(begin, end):
return get_axis(begin, end, "opaque")


@register
class EvaluateIntrin(Intrin):
def __init__(self):
def evaluate(value):
return tvm.tir.Evaluate(value)

super().__init__(evaluate, stmt=True)


@register
class StoreIntrin(Intrin):
def __init__(self):
def store(var, index, value, predicate=True):
return tvm.tir.Store(var, value, index, predicate)

super().__init__(store, stmt=True)
Loading

0 comments on commit 129333b

Please sign in to comment.