Skip to content

Commit

Permalink
[TIR][Hybrid] Hybrid Script Improvement (apache#6507)
Browse files Browse the repository at this point in the history
* [TIR][Hybrid] update

* [TIR][Hybrid] python formatting
  • Loading branch information
spectrometerHBH authored and trevor-m committed Sep 18, 2020
1 parent ae9a3da commit b5bd021
Show file tree
Hide file tree
Showing 10 changed files with 1,207 additions and 1,229 deletions.
3 changes: 1 addition & 2 deletions python/tvm/hybrid/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@
"""FFI APIs for tvm.hybrid"""
import tvm._ffi


tvm._ffi._init_api("tir.hybrid", __name__)
tvm._ffi._init_api("hybrid", __name__)
108 changes: 70 additions & 38 deletions python/tvm/hybrid/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,114 +23,146 @@
from .registry import register_intrin


@register_intrin
@register_intrin()
def bool(imm):
return tvm.tir.const(imm.value, "bool")
return tvm.tir.const(imm, "bool")


@register_intrin
@register_intrin()
def int8(imm):
return tvm.tir.const(imm.value, "int8")
return tvm.tir.const(imm, "int8")


@register_intrin
@register_intrin()
def int16(imm):
return tvm.tir.const(imm.value, "int16")
return tvm.tir.const(imm, "int16")


@register_intrin
@register_intrin()
def int32(imm):
return tvm.tir.const(imm.value, "int32")
return tvm.tir.const(imm, "int32")


@register_intrin
@register_intrin()
def int64(imm):
return tvm.tir.const(imm.value, "int64")
return tvm.tir.const(imm, "int64")


@register_intrin
@register_intrin()
def uint8(imm):
return tvm.tir.const(imm.value, "uint8")
return tvm.tir.const(imm, "uint8")


@register_intrin
@register_intrin()
def uint16(imm):
return tvm.tir.const(imm.value, "uint16")
return tvm.tir.const(imm, "uint16")


@register_intrin
@register_intrin()
def uint32(imm):
return tvm.tir.const(imm.value, "uint32")
return tvm.tir.const(imm, "uint32")


@register_intrin
@register_intrin()
def uint64(imm):
return tvm.tir.const(imm.value, "uint64")
return tvm.tir.const(imm, "uint64")


@register_intrin
@register_intrin()
def float8(imm):
return tvm.tir.const(imm.value, "float8")
return tvm.tir.const(imm, "float8")


@register_intrin
@register_intrin()
def float16(imm):
return tvm.tir.const(imm.value, "float16")
return tvm.tir.const(imm, "float16")


@register_intrin
@register_intrin()
def float32(imm):
return tvm.tir.const(imm.value, "float32")
return tvm.tir.const(imm, "float32")


@register_intrin
@register_intrin()
def float64(imm):
return tvm.tir.const(imm.value, "float64")
return tvm.tir.const(imm, "float64")


@register_intrin
@register_intrin()
def floordiv(x, y):
return tvm.tir.floordiv(x, y)


@register_intrin
@register_intrin()
def floormod(x, y):
return tvm.tir.floormod(x, y)


@register_intrin
@register_intrin()
def load(dtype, var, index, predicate=True):
return tvm.tir.Load(dtype, var, index, predicate)


@register_intrin
def cast(dtype, value):
@register_intrin()
def cast(value, dtype):
return tvm.tir.Cast(dtype, value)


@register_intrin
@register_intrin()
def ramp(base, stride, lanes):
lanes = lanes.value if not isinstance(lanes, int) else lanes
return tvm.tir.Ramp(base, stride, lanes)


@register_intrin
@register_intrin()
def broadcast(value, lanes):
lanes = lanes.value if not isinstance(lanes, int) else lanes
return tvm.tir.Broadcast(value, lanes)


@register_intrin
@register_intrin()
def evaluate(value):
return tvm.tir.Evaluate(value)


@register_intrin
@register_intrin()
def store(var, index, value, predicate=True):
return tvm.tir.Store(var, value, index, predicate)


@register_intrin
@register_intrin()
def iter_var(var, dom, iter_type, thread_tag):
iter_type = getattr(tvm.tir.IterVar, iter_type)
return tvm.tir.IterVar(dom, var, iter_type, thread_tag)


@register_intrin()
def max(a, b): # pylint: disable=redefined-builtin
return tvm.tir.Max(a, b)


def get_axis(begin, end, iter_type):
ana = tvm.arith.Analyzer()
extent = ana.simplify(end - begin)
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)

iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type])


@register_intrin()
def range(begin, end):
return get_axis(begin, end, "data_par")


@register_intrin()
def reduce_axis(begin, end):
return get_axis(begin, end, "reduce")


@register_intrin()
def scan_axis(begin, end):
return get_axis(begin, end, "scan")


@register_intrin()
def opaque_axis(begin, end):
return get_axis(begin, end, "opaque")
Loading

0 comments on commit b5bd021

Please sign in to comment.