Skip to content

Commit

Permalink
[Hybrid script] Backend support (apache#2477)
Browse files Browse the repository at this point in the history
* a preliminary version is done?

* we no longer need the redundant hybrid/api.py

* support assert stmt

* cast supported

* intrin -> runtime; util is mainly in charge of compilation time

* assert statement

* fix python lint

* fix cpp lint

* on the way to module

* rollback .cc

* fix typo, no direct expose then

* @vinx13 ceil is added i guess?

* wip...

* temp commit

* fix import

* i preliminary version is done?

* on the way to build hybrid module

* nearly fixed...

* dumped python are equiv as original python

* on the way to bootstrap

* cpu bootstrap done

* bootstrap!

* fix lint

* fix doc

* resolve some review concerns

* support load/save

* fix lint

* thanks to xqdan fixed my typo

* fix build, make dump non-optional

* add vthread

* jesus why i added this
  • Loading branch information
were authored and AWS Neo committed Feb 20, 2019
1 parent 80943dd commit 82ed3c9
Show file tree
Hide file tree
Showing 17 changed files with 1,091 additions and 147 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/contrib/HybridDump.cmake)

add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS})
Expand Down
3 changes: 3 additions & 0 deletions cmake/modules/contrib/HybridDump.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message(STATUS "Build with contrib.hybriddump")
file(GLOB HYBRID_CONTRIB_SRC src/contrib/hybrid/*.cc)
list(APPEND COMPILER_SRCS ${HYBRID_CONTRIB_SRC})
14 changes: 14 additions & 0 deletions docs/langref/hybrid_script.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,20 @@ You can also do loop-thread bind by writing code like this:
a[tx] = b[tx]
Assert Statement
~~~~~~~~~~~~~~~~

Assert statement is supported, you can simply use it as it is in standard Python.

.. code-block:: python
assert cond, mesg
.. note::

``Assert`` is NOT a function call. Users are encouraged to use assert in the way
presented above --- condition followed by message. It fits both Python AST and HalideIR.

Keywords
~~~~~~~~
- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr``
Expand Down
25 changes: 20 additions & 5 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,25 @@ def get_binds(args, binds=None):
return binds, arg_list


def form_body(sch):
"""According to the given schedule, form the raw body
Parameters
----------
sch : tvm.schedule.Schedule
The given scheduler to form the raw body
Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt


def lower(sch,
args,
name="default_function",
Expand Down Expand Up @@ -337,11 +356,7 @@ def lower(sch,

# Phase 0
if isinstance(sch, schedule.Schedule):
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
stmt = form_body(sch)

for f in lower_phase0:
stmt = f(stmt)
Expand Down
75 changes: 72 additions & 3 deletions python/tvm/hybrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,77 @@
1. Users can write some preliminary versions of the computation patterns
have not been supported yet and verify it across the real execution and
python semantic emulation.
2. Developers can build HalideIR by writing Python code.
2. So far, it is a text format dedicated to HalideIR Phase 0. Refer tvm.lower
for more details. A larger ambition of this module is to support all levels of
HalideIR.
"""

from .api import script
from .parser import parse_python
# TODO(@were): Make this module more complete.
# 1. Support HalideIR dumping to Hybrid Script
# 2. Support multi-level HalideIR

from __future__ import absolute_import as _abs

from .._ffi.base import decorate
from .._ffi.function import _init_api
from ..build_module import form_body

from .module import HybridModule
from .parser import source_to_op
from .util import _pruned_source


def script(pyfunc):
"""Decorate a python function function as hybrid script.
The hybrid function support emulation mode and parsing to
the internal language IR.
Returns
-------
hybrid_func : function
A decorated hybrid script function.
"""
def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
return source_to_op(src, func.__globals__, args)

from .runtime import _enter_hybrid_runtime, _restore_runtime
intersect = _enter_hybrid_runtime(func)
value = func(*args, **kwargs)
_restore_runtime(func, intersect)
return value

return decorate(pyfunc, wrapped_func)


def build(sch, inputs, outputs, name="hybrid_func"):
"""Dump the corrent schedule to hybrid module
Parameters
----------
sch: Schedule
The schedule to be dumped
inputs: An array of Tensors or Vars
The inputs of the function body
outputs: An array of Tensors
The outputs of the function body
Returns
-------
module: HybridModule
The built results is wrapped in a HybridModule.
The usage of HybridModule is roughly the same as normal TVM-built modules.
"""

stmt = form_body(sch)
src = _Dump(stmt, inputs, outputs, name)

return HybridModule(src, name)


_init_api("tvm.hybrid")
43 changes: 0 additions & 43 deletions python/tvm/hybrid/api.py

This file was deleted.

27 changes: 27 additions & 0 deletions python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
from ..intrin import call_pure_intrin

#pylint: disable=redefined-builtin

Expand Down Expand Up @@ -104,3 +105,29 @@ def len(func_id, args):
except: #pylint: disable=bare-except
_internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
return _api.convert(args[0].shape[0])


def _cast(func_id, args):
_internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.Expr), \
"Only one expression can be cast")
return _make.Cast(func_id, args[0])

float16 = float32 = float64 = _cast #pylint: disable=invalid-name
int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name
uint8 = uint16 = uint32 = uint64 = _cast #pylint: disable=invalid-name


def ceil_div(func_id, args):
_internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!")
_internal_assert(args.__len__() == 2, "2 arguments expected for division!")
_internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
_internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
a, b = args[0], args[1]
return (a + b - 1) / b


def likely(func_id, args):
_internal_assert(args.__len__() == 1, \
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
return call_pure_intrin(args[0].dtype, 'likely', *args)
100 changes: 100 additions & 0 deletions python/tvm/hybrid/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Methods and data structures to support dumping HalideIR to Hybrid Script.
This allows users to do quick hack to generated HalideIR and cast it back to
TVM modules.
To enable this feature, you need to build with -DUSE_HYBRID_DUMP=ON.
"""

import ast
import imp

from ..contrib import util
from .util import _internal_assert
from .util import _is_tvm_arg_types
from .parser import source_to_op


class HybridModule(object):
"""The usage of Hybrid Module is very similar to conventional TVM module,
but conventional TVM module requires a function body which is already fully
lowered. This contradicts to the fact that Hybrid Module is originally a text
format for Phase 0 HalideIR. Thus, a totally separated module is defined."""


def __init__(self, src=None, name=None):
"""The constructor of this a hybrid module
Parameters
----------
src : str
The source code of this module
name : str
The name of this module
"""
self.src_ = self.name = self.func_ = self.root_ = None
if src is not None:
temp = util.tempdir()
dst = temp.relpath("script.py")
with open(dst, 'w') as f:
f.write("import tvm\n@tvm.hybrid.script\n%s" % src)

if name is not None:
self.name = name
self.load(dst)


def __call__(self, *args):
if _is_tvm_arg_types(args):
return source_to_op(self.root_, globals(), args)
return self.func_(*args)


def get_source(self):
return self.src_


def save(self, path):
if not path.endswith('.py'):
path = path + '.py'
with open(path, 'w') as f:
f.write(self.src_)


def load(self, path):
"""Load the module from a python file
Parameters
----------
path : str
Path to the given python file
"""
with open(path, 'r') as f:
self.src_ = f.read()

src = self.src_

class FindFunc(ast.NodeVisitor):
""" Find the function in module to be loaded module. """
#pylint: disable=invalid-name
def __init__(self):
self.name = None
self.root = None


def visit_FunctionDef(self, node):
_internal_assert(self.name is None, "For now, only one function supported!")
self.name = node.name
_internal_assert(self.root is None, "For now, only one function supported!")
self.root = node

root = ast.parse(src)
finder = FindFunc()
finder.visit(root)
_internal_assert(finder.name is not None and finder.root is not None, \
"No function found!")
if self.name is None:
self.name = finder.name
self.root_ = finder.root
py_module = imp.load_source(self.name, path)
self.func_ = getattr(py_module, self.name)
Loading

0 comments on commit 82ed3c9

Please sign in to comment.