Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND] [HYBRID] Non-zero starting supported; Buffer AttrStmt add! #1330

Merged
merged 4 commits into from
Jun 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/tvm/hybrid/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, tag, ext):
unroll = vectorize = parallel = _range #pylint: disable=invalid-name


def allocate(shape, dtype='float32'):
def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update docstring to include all arguments

"""Allocate a buffer with given shape

Parameters
Expand All @@ -38,6 +38,8 @@ def allocate(shape, dtype='float32'):
The shape of the tensor to be allocated
dtype: string
The data type of the tensor
scope: string
The storage scope of the tensor

Returns
-------
Expand Down
101 changes: 72 additions & 29 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ast
import operator
import sys
from .util import make_nop, make_const_true, make_range_one, halide_imm_types
from .util import make_nop, halide_imm_types
from .intrin import LOOP_INTRIN, MATH_INTRIN
from .var_decl import determine_variable_usage
from ..api import thread_axis
Expand Down Expand Up @@ -75,7 +75,8 @@ def __init__(self, args, usage, func_name=None):
self.args = args[:]
self.usage = usage.copy()
self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
self.buffers = {}
self.var_buffers = {} # Buffers formed by mutatble variables
self.alloc_buffers = {} # Buffers formed by allocate instructions
self.loops_above = {} # State variable that indicates loop levels above the current node
self.var_consts = {} # Variables that are determined as readonly in previous stage
self.func_name = func_name # The name of the function to be lowered
Expand All @@ -87,19 +88,30 @@ def wrap_up_realize(self, node, body):
for key, val in self.usage.items():
if key in self.var_consts.keys():
continue
_, scope, _ = val
if scope == node:
_buf = self.buffers[key]
_, level, _ = val
if level == node:
if key in self.var_buffers.keys():
_buf = self.var_buffers[key]
_scope = 'global'
else:
_buf, _scope = self.alloc_buffers[key]
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype
_one = make_range_one()
_true = make_const_true()
body = _make.Realize(_buf.op, 0, _dtype, [_one], _true, body)
_true = _api.convert(True)
body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
return body


def _check_id_a_buffer(self, s):
if s not in self._args.keys():
def _get_buffer_from_id(self, s):
if s not in self._args.keys() and s not in self.alloc_buffers.keys():
raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s)
if s in self._args.keys() and s in self.alloc_buffers.keys():
raise ValueError("%s, a buffer cannot be both argument and allocated!" % s)
if s in self._args.keys():
return self._args[s]
return self.alloc_buffers[s][0]



#pylint: disable=invalid-name, missing-docstring
Expand Down Expand Up @@ -138,8 +150,8 @@ def visit_Name(self, node):
if _id not in self.usage.keys():
raise ValueError("This id %s is expected to be a defined variable!" % _id)
# Buffer
if _id in self.buffers.keys():
_buf = self.buffers[_id]
if _id in self.var_buffers.keys():
_buf = self.var_buffers[_id]
return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
# Compilation time constant
if _id not in self.var_consts.keys():
Expand All @@ -155,7 +167,9 @@ def visit_Assign(self, node):
if len(node.targets) != 1:
raise ValueError("So far only one-valued assignment is supported!")
lhs = node.targets[0]
rhs = _ir_pass.Simplify(self.visit(node.value))
rhs = self.visit(node.value)
if isinstance(rhs, _expr.Expr):
rhs = _ir_pass.Simplify(rhs)
if isinstance(lhs, ast.Name):
#TODO: support defined intermediate buffer later
lhs_ = lhs
Expand All @@ -166,25 +180,31 @@ def visit_Assign(self, node):
if decl == lhs_:
if lhs in self.var_consts.keys():
raise ValueError("BUG: A constant cannot be overwritten!")
if lhs in self.buffers.keys():
if lhs in self.var_buffers.keys() or lhs in self.alloc_buffers.keys():
raise ValueError("BUG: This value should not be defined before this point!")
if isinstance(rhs, tuple):
shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, scope)
return make_nop()
if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
self.var_consts[lhs] = rhs
else:
self.buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
self.var_buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
if lhs in self.var_consts.keys():
return make_nop()
else:
if lhs not in self.buffers.keys():
raise ValueError("BUG: This value should be defined before!")
return _make.Provide(self.buffers[lhs].op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
if lhs not in self.var_buffers.keys():
raise ValueError("BUG: This variable should be defined before!")
tgt = self.var_buffers[lhs]
return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
else:
lhs = self.visit(lhs)
if not isinstance(lhs, _expr.Call):
raise ValueError("An array access's LHS is expected to be a expr.Call!")
#TODO: support slice later
self._check_id_a_buffer(lhs.name)
return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args)
buf = self._get_buffer_from_id(lhs.name)
return _make.Provide(buf.op, 0, rhs, lhs.args)


def visit_Index(self, node):
Expand All @@ -197,8 +217,7 @@ def visit_Subscript(self, node):
args = self.visit(node.slice)
if isinstance(node.value, ast.Name):
array = node.value.id
self._check_id_a_buffer(array)
_buf = self._args[array]
_buf = self._get_buffer_from_id(array)
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
elif isinstance(node.value, ast.Attribute):
if not isinstance(node.value.value, ast.Name):
Expand All @@ -211,8 +230,8 @@ def visit_Subscript(self, node):
#TODO: maybe support non-constant value later?
if not isinstance(args, (_expr.IntImm, _expr.UIntImm)):
raise ValueError("So far only constant shape access supported!")
self._check_id_a_buffer(node.value.value.id)
return self._args[node.value.value.id].shape[args.value]
buf = self._get_buffer_from_id(node.value.value.id)
return buf.shape[args.value]
else:
raise ValueError("Not supported yet!")

Expand Down Expand Up @@ -303,8 +322,30 @@ def visit_Call(self, node):
elif func_id in MATH_INTRIN:
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id == 'allocate':
#TODO: Support it later!
return make_nop()
if not isinstance(node.args[0], ast.Tuple):
raise ValueError("allocate's first argument should be a tuple of shape!")
shape = tuple(self.visit(i) for i in node.args[0].elts)
for i in shape:
if not isinstance(i, _expr.Expr):
raise ValueError("The shape should be an expression")
if n > 1:
if not isinstance(node.args[1], ast.Str):
raise ValueError("The data type should be an string")
dtype = node.args[1].s
else:
dtype = 'float32'
if n > 2:
if not isinstance(node.args[2], ast.Str):
raise ValueError("The data type should be an string")
scope = node.args[2].s
else:
scope = 'global'
return (shape, dtype, scope)
elif func_id == 'max' or func_id == 'min':
if n != 2:
raise ValueError("Max/Min function should have 2 elements")
a, b = self.visit(node.args[0]), self.visit(node.args[1])
return getattr(_make, func_id.title())(a, b)
else:
raise ValueError("Function call not supported yet!")

Expand All @@ -317,8 +358,10 @@ def visit_For(self, node):
if iter_var is None:
if for_type is None:
raise ValueError("The loop bind function parse error!")
iter_var = _api.var(_name)
self.loops_above[_name] = iter_var
offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
offset = iter_var + low
self.loops_above[_name] = offset
else:
if for_type is not None:
raise ValueError("The loop iterating function parse error!")
Expand All @@ -328,7 +371,7 @@ def visit_For(self, node):
if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
else:
res = _make.For(iter_var, low, ext, for_type, 0, _body)
res = _make.For(iter_var, _api.const(0, dtype='int32'), ext, for_type, 0, _body)
self.loops_above.pop(_name)
return res

Expand Down
10 changes: 0 additions & 10 deletions python/tvm/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@ def make_nop():
return _make.Evaluate(_api.const(0, dtype='int32'))


def make_range_one():
"""Returns a [0, 1] range node in HalideIR."""
return _make.range_by_min_extent(0, 1)


def make_const_true():
"""Returns a constant True node in HalideIR."""
return _api.convert(True)


def _pruned_source(func):
"""Prune source code's extra leading spaces"""
lines = inspect.getsource(func).split('\n')
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/hybrid/var_decl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def visit_Call(self, node):
#No function pointer supported so far
if not isinstance(node.func, ast.Name):
raise ValueError("Function call should be an id")
if (node.func.id not in HYBRID_GLOBALS.keys()) and node.func.id != 'range':
func_id = node.func.id
if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']:
raise ValueError("Function call id not in intrinsics' list")
for elem in node.args:
self.visit(elem)
Expand All @@ -64,7 +65,6 @@ def visit_Name(self, node):
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
loop = self.scope_level[-1]
usage.add(type(node.ctx))
self.status[node.id] = (decl, loop, usage)

Expand Down
Loading