Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hybrid Script] Inter-function call supported! #2287

Merged
merged 24 commits into from
Dec 19, 2018
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions python/tvm/hybrid/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,15 @@ def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
parser = parse_python(src, args)
parser = parse_python(src, func.__globals__, args)

input_tensors = []
for i in args:
if isinstance(i, Tensor):
input_tensors.append(i)

op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]

return res[0] if len(res) == 1 else res

intersect = _enter_hybrid_runtime(func)
Expand Down
93 changes: 93 additions & 0 deletions python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
were marked this conversation as resolved.
Show resolved Hide resolved

import ast
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from ..container import Array
from ..ir_pass import Equal
from ..stmt import For
from .util import _internal_assert

#pylint: disable=redefined-builtin

LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
}

def _range(annotation, args):
"""Handling TVM loop types"""
n = len(args)
if n == 1:
low, ext = _api.const(0, dtype='int32'), args[0]
were marked this conversation as resolved.
Show resolved Hide resolved
else:
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
low, ext = args[0], args[1]
were marked this conversation as resolved.
Show resolved Hide resolved
if not Equal(low, _api.const(0, dtype='int32')):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
iter_var = None
return iter_var, low, ext, for_type


range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name


def bind(func_id, args):
were marked this conversation as resolved.
Show resolved Hide resolved
"""Handling TVM thread binding"""
_internal_assert(func_id == "bind", "This function cannot be directly invoked!")
were marked this conversation as resolved.
Show resolved Hide resolved
_internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(args[0], ast.Str), \
"A loop bind's first argument should be a string!")
iter_var = thread_axis(args[0])
low, ext = _api.const(0, dtype='int32'), args[1]
for_type = None
return iter_var, low, ext, for_type


def _math_intrin(func_id, args):
from .. import intrin
return getattr(intrin, func_id)(*args)

sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name


def _min_max(func_id, args):
_internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
return getattr(_make, func_id.title())(args[0], args[1])


min = max = _min_max #pylint: disable=invalid-name


def _allocate_tensor(func_id, args):
"""Handling TVM tensor allocation.
You may refer hybrid.intrin.allocate for more details."""
n = len(args)
were marked this conversation as resolved.
Show resolved Hide resolved
_internal_assert(isinstance(_api.convert(args[0]), Array), \
"allocate's first argument should be a tuple of shape!")
shape = args[0]
for i in shape:
_internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
if n > 1:
_internal_assert(isinstance(args[1], str),
"The data type should be an str")
_internal_assert(args[1].startswith('int') or args[1].startswith('float'), \
"The data type should be either int or float!")
dtype = args[1]
else:
dtype = 'float32'
were marked this conversation as resolved.
Show resolved Hide resolved
if n > 2:
_internal_assert(isinstance(args[2], str), \
"The data scope should be an string")
_internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
scope = args[2]
else:
scope = 'global' if func_id != 'output_tensor' else 'output'
return (shape, dtype, scope)

output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
15 changes: 1 addition & 14 deletions python/tvm/hybrid/intrin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""

import numpy
from ..stmt import For

class _range(object):
"""Base class of the loop ranges in hybrid script"""
Expand Down Expand Up @@ -102,15 +101,3 @@ def sigmoid(x):
'sigmoid' : sigmoid,
'popcount' : popcount
}


LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
'bind' : None
}


MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount']
Loading