Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] refactor #6734

Merged
merged 3 commits into from
Oct 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/tvm/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@
# under the License.
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""

from .utils import create_module, asscript, tir, module
from .parser import from_source
from .parser import from_source, create_module, asscript, tir, module
2 changes: 1 addition & 1 deletion python/tvm/script/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tvmscript"""
"""FFI APIs for tvm.script"""
import tvm._ffi

tvm._ffi._init_api("script", __name__)
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script Scope Emitter for TIR"""
"""TVM Script Context Maintainer for TIR"""

from tvm.te import schedule


class ScopeEmitter:
"""Maintain the nodes and symbols of scopes"""
class ContextMaintainer:
"""Maintain all the necessary context info"""

def __init__(self, parser):
self.node_stack = [[]] # AST nodes of scopes
self.symbols = [dict()] # Symbols of scopes
# scope context
self.node_stack = [] # AST nodes of scopes
self.symbols = [] # symbols of scopes
# function context
self.func_params = [] # parameter list of function
self.func_buffer_map = {} # buffer_map of function
self.func_dict_attr = {} # func_attr of function
self.func_var_env_dict = {} # map from var to env_name
# parser
self.parser = parser

def pop_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()

def new_scope(self):
""" Creating a new scope """
self.node_stack.append([])
def new_scope(self, nodes=None):
"""Creating a new scope"""
if nodes is None:
nodes = []
self.node_stack.append(list(reversed(nodes)))
self.symbols.append(dict())

def update_symbol(self, name, symbol):
Expand All @@ -60,3 +69,6 @@ def lookup_symbol(self, name):
if name in symbols:
return symbols[name]
return None

def report_error(self, message):
self.parser.report_error(message)
100 changes: 59 additions & 41 deletions python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,127 +14,127 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script Parser Intrinsic Functions

IRNodes (StmtNodes without body, PrimExprNodes and more) are called intrins
"""
# pylint: disable=redefined-builtin
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
import tvm.tir
from .registry import register_intrin
from .registry import register
from .utils import get_param_list


class Intrin:
def __init__(self, intrin, stmt=False):
self.intrin = intrin
self.stmt = stmt

def signature(self):
return "tir." + self.intrin.__name__, get_param_list(self.intrin)

def handle(self, arg_list):
return self.intrin(*arg_list)

@register_intrin()

@register
def bool(imm):
return tvm.tir.const(imm, "bool")


@register_intrin()
@register
def int8(imm):
return tvm.tir.const(imm, "int8")


@register_intrin()
@register
def int16(imm):
return tvm.tir.const(imm, "int16")


@register_intrin()
@register
def int32(imm):
return tvm.tir.const(imm, "int32")


@register_intrin()
@register
def int64(imm):
return tvm.tir.const(imm, "int64")


@register_intrin()
@register
def uint8(imm):
return tvm.tir.const(imm, "uint8")


@register_intrin()
@register
def uint16(imm):
return tvm.tir.const(imm, "uint16")


@register_intrin()
@register
def uint32(imm):
return tvm.tir.const(imm, "uint32")


@register_intrin()
@register
def uint64(imm):
return tvm.tir.const(imm, "uint64")


@register_intrin()
@register
def float8(imm):
return tvm.tir.const(imm, "float8")


@register_intrin()
@register
def float16(imm):
return tvm.tir.const(imm, "float16")


@register_intrin()
@register
def float32(imm):
return tvm.tir.const(imm, "float32")


@register_intrin()
@register
def float64(imm):
return tvm.tir.const(imm, "float64")


@register_intrin()
@register
def floordiv(x, y):
return tvm.tir.floordiv(x, y)


@register_intrin()
@register
def floormod(x, y):
return tvm.tir.floormod(x, y)


@register_intrin()
@register
def load(dtype, var, index, predicate=True):
return tvm.tir.Load(dtype, var, index, predicate)


@register_intrin()
@register
def cast(value, dtype):
return tvm.tir.Cast(dtype, value)


@register_intrin()
@register
def ramp(base, stride, lanes):
return tvm.tir.Ramp(base, stride, lanes)


@register_intrin()
@register
def broadcast(value, lanes):
return tvm.tir.Broadcast(value, lanes)


@register_intrin()
def evaluate(value):
return tvm.tir.Evaluate(value)


@register_intrin()
def store(var, index, value, predicate=True):
return tvm.tir.Store(var, value, index, predicate)


@register_intrin()
@register
def iter_var(var, dom, iter_type, thread_tag):
iter_type = getattr(tvm.tir.IterVar, iter_type)
return tvm.tir.IterVar(dom, var, iter_type, thread_tag)


@register_intrin()
@register
def max(a, b): # pylint: disable=redefined-builtin
return tvm.tir.Max(a, b)

Expand All @@ -148,21 +148,39 @@ def get_axis(begin, end, iter_type):
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type])


@register_intrin()
@register
def range(begin, end):
return get_axis(begin, end, "data_par")


@register_intrin()
@register
def reduce_axis(begin, end):
return get_axis(begin, end, "reduce")


@register_intrin()
@register
def scan_axis(begin, end):
return get_axis(begin, end, "scan")


@register_intrin()
@register
def opaque_axis(begin, end):
return get_axis(begin, end, "opaque")


@register
class EvaluateIntrin(Intrin):
def __init__(self):
def evaluate(value):
return tvm.tir.Evaluate(value)

super().__init__(evaluate, stmt=True)


@register
class StoreIntrin(Intrin):
def __init__(self):
def store(var, index, value, predicate=True):
return tvm.tir.Store(var, value, index, predicate)

super().__init__(store, stmt=True)
Loading