Skip to content

Commit

Permalink
[TIR][Hybrid] Hybrid Script Support for TIR (apache#6227)
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored Aug 10, 2020
1 parent 8bb99fb commit 87d6ccd
Show file tree
Hide file tree
Showing 15 changed files with 3,114 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
# tvm.parser
from . import parser

# tvm tir hybrid script
from . import hybrid

# others
from . import arith

Expand Down
20 changes: 20 additions & 0 deletions python/tvm/hybrid/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Hybrid Script APIs of TVM Python Package, aimed to support TIR"""

from .utils import create_module, ashybrid, script
from .parser import from_source
21 changes: 21 additions & 0 deletions python/tvm/hybrid/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.hybrid"""
import tvm._ffi


tvm._ffi._init_api("tir.hybrid", __name__)
136 changes: 136 additions & 0 deletions python/tvm/hybrid/intrin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Hybrid Script Parser Intrinsic Functions
IRNodes (StmtNodes without body, PrimExprNodes and more) are called intrins
"""
# pylint: disable=redefined-builtin
import tvm.tir
from .registry import register_intrin


@register_intrin
def bool(imm):
return tvm.tir.const(imm.value, "bool")


@register_intrin
def int8(imm):
return tvm.tir.const(imm.value, "int8")


@register_intrin
def int16(imm):
return tvm.tir.const(imm.value, "int16")


@register_intrin
def int32(imm):
return tvm.tir.const(imm.value, "int32")


@register_intrin
def int64(imm):
return tvm.tir.const(imm.value, "int64")


@register_intrin
def uint8(imm):
return tvm.tir.const(imm.value, "uint8")


@register_intrin
def uint16(imm):
return tvm.tir.const(imm.value, "uint16")


@register_intrin
def uint32(imm):
return tvm.tir.const(imm.value, "uint32")


@register_intrin
def uint64(imm):
return tvm.tir.const(imm.value, "uint64")


@register_intrin
def float8(imm):
return tvm.tir.const(imm.value, "float8")


@register_intrin
def float16(imm):
return tvm.tir.const(imm.value, "float16")


@register_intrin
def float32(imm):
return tvm.tir.const(imm.value, "float32")


@register_intrin
def float64(imm):
return tvm.tir.const(imm.value, "float64")


@register_intrin
def floordiv(x, y):
return tvm.tir.floordiv(x, y)


@register_intrin
def floormod(x, y):
return tvm.tir.floormod(x, y)


@register_intrin
def load(dtype, var, index, predicate=True):
return tvm.tir.Load(dtype, var, index, predicate)


@register_intrin
def cast(dtype, value):
return tvm.tir.Cast(dtype, value)


@register_intrin
def ramp(base, stride, lanes):
lanes = lanes.value if not isinstance(lanes, int) else lanes
return tvm.tir.Ramp(base, stride, lanes)


@register_intrin
def broadcast(value, lanes):
lanes = lanes.value if not isinstance(lanes, int) else lanes
return tvm.tir.Broadcast(value, lanes)


@register_intrin
def evaluate(value):
return tvm.tir.Evaluate(value)


@register_intrin
def store(var, index, value, predicate=True):
return tvm.tir.Store(var, value, index, predicate)


@register_intrin
def iter_var(var, dom, iter_type, thread_tag):
iter_type = getattr(tvm.tir.IterVar, iter_type)
return tvm.tir.IterVar(dom, var, iter_type, thread_tag)
50 changes: 50 additions & 0 deletions python/tvm/hybrid/meta_unparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unparse meta AST node into a dict"""
# pylint: disable=invalid-name

from typed_ast import ast3 as ast


class MetaUnparser(ast.NodeVisitor):
"""Python AST Visitor to unparse meta AST node into a dict"""

def visit_Dict(self, node):
keys = [self.visit(key) for key in node.keys]
values = [self.visit(value) for value in node.values]
return dict(zip(keys, values))

def visit_Tuple(self, node):
return tuple(self.visit(element) for element in node.elts)

def visit_List(self, node):
return [self.visit(element) for element in node.elts]

def visit_keyword(self, node):
return node.arg, self.visit(node.value)

def visit_NameConstant(self, node):
return node.value

def visit_Constant(self, node):
return node.value

def visit_Num(self, node):
return node.n

def visit_Str(self, node):
return node.s
Loading

0 comments on commit 87d6ccd

Please sign in to comment.