Skip to content

Commit

Permalink
[FRONTEND] [HYBRID] Non-zero starting supported; Buffer AttrStmt add! (
Browse files Browse the repository at this point in the history
  • Loading branch information
were authored and tqchen committed Jun 25, 2018
1 parent f9d8427 commit f721a64
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 75 deletions.
4 changes: 3 additions & 1 deletion python/tvm/hybrid/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, tag, ext):
unroll = vectorize = parallel = _range #pylint: disable=invalid-name


def allocate(shape, dtype='float32'):
def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
"""Allocate a buffer with given shape
Parameters
Expand All @@ -38,6 +38,8 @@ def allocate(shape, dtype='float32'):
The shape of the tensor to be allocated
dtype: string
The data type of the tensor
scope: string
The storage scope of the tensor
Returns
-------
Expand Down
101 changes: 72 additions & 29 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ast
import operator
import sys
from .util import make_nop, make_const_true, make_range_one, halide_imm_types
from .util import make_nop, halide_imm_types
from .intrin import LOOP_INTRIN, MATH_INTRIN
from .var_decl import determine_variable_usage
from ..api import thread_axis
Expand Down Expand Up @@ -75,7 +75,8 @@ def __init__(self, args, usage, func_name=None):
self.args = args[:]
self.usage = usage.copy()
self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
self.buffers = {}
self.var_buffers = {} # Buffers formed by mutatble variables
self.alloc_buffers = {} # Buffers formed by allocate instructions
self.loops_above = {} # State variable that indicates loop levels above the current node
self.var_consts = {} # Variables that are determined as readonly in previous stage
self.func_name = func_name # The name of the function to be lowered
Expand All @@ -87,19 +88,30 @@ def wrap_up_realize(self, node, body):
for key, val in self.usage.items():
if key in self.var_consts.keys():
continue
_, scope, _ = val
if scope == node:
_buf = self.buffers[key]
_, level, _ = val
if level == node:
if key in self.var_buffers.keys():
_buf = self.var_buffers[key]
_scope = 'global'
else:
_buf, _scope = self.alloc_buffers[key]
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype
_one = make_range_one()
_true = make_const_true()
body = _make.Realize(_buf.op, 0, _dtype, [_one], _true, body)
_true = _api.convert(True)
body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
return body


def _check_id_a_buffer(self, s):
if s not in self._args.keys():
def _get_buffer_from_id(self, s):
if s not in self._args.keys() and s not in self.alloc_buffers.keys():
raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s)
if s in self._args.keys() and s in self.alloc_buffers.keys():
raise ValueError("%s, a buffer cannot be both argument and allocated!" % s)
if s in self._args.keys():
return self._args[s]
return self.alloc_buffers[s][0]



#pylint: disable=invalid-name, missing-docstring
Expand Down Expand Up @@ -138,8 +150,8 @@ def visit_Name(self, node):
if _id not in self.usage.keys():
raise ValueError("This id %s is expected to be a defined variable!" % _id)
# Buffer
if _id in self.buffers.keys():
_buf = self.buffers[_id]
if _id in self.var_buffers.keys():
_buf = self.var_buffers[_id]
return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
# Compilation time constant
if _id not in self.var_consts.keys():
Expand All @@ -155,7 +167,9 @@ def visit_Assign(self, node):
if len(node.targets) != 1:
raise ValueError("So far only one-valued assignment is supported!")
lhs = node.targets[0]
rhs = _ir_pass.Simplify(self.visit(node.value))
rhs = self.visit(node.value)
if isinstance(rhs, _expr.Expr):
rhs = _ir_pass.Simplify(rhs)
if isinstance(lhs, ast.Name):
#TODO: support defined intermediate buffer later
lhs_ = lhs
Expand All @@ -166,25 +180,31 @@ def visit_Assign(self, node):
if decl == lhs_:
if lhs in self.var_consts.keys():
raise ValueError("BUG: A constant cannot be overwritten!")
if lhs in self.buffers.keys():
if lhs in self.var_buffers.keys() or lhs in self.alloc_buffers.keys():
raise ValueError("BUG: This value should not be defined before this point!")
if isinstance(rhs, tuple):
shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, scope)
return make_nop()
if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
self.var_consts[lhs] = rhs
else:
self.buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
self.var_buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
if lhs in self.var_consts.keys():
return make_nop()
else:
if lhs not in self.buffers.keys():
raise ValueError("BUG: This value should be defined before!")
return _make.Provide(self.buffers[lhs].op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
if lhs not in self.var_buffers.keys():
raise ValueError("BUG: This variable should be defined before!")
tgt = self.var_buffers[lhs]
return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
else:
lhs = self.visit(lhs)
if not isinstance(lhs, _expr.Call):
raise ValueError("An array access's LHS is expected to be a expr.Call!")
#TODO: support slice later
self._check_id_a_buffer(lhs.name)
return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args)
buf = self._get_buffer_from_id(lhs.name)
return _make.Provide(buf.op, 0, rhs, lhs.args)


def visit_Index(self, node):
Expand All @@ -197,8 +217,7 @@ def visit_Subscript(self, node):
args = self.visit(node.slice)
if isinstance(node.value, ast.Name):
array = node.value.id
self._check_id_a_buffer(array)
_buf = self._args[array]
_buf = self._get_buffer_from_id(array)
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
elif isinstance(node.value, ast.Attribute):
if not isinstance(node.value.value, ast.Name):
Expand All @@ -211,8 +230,8 @@ def visit_Subscript(self, node):
#TODO: maybe support non-constant value later?
if not isinstance(args, (_expr.IntImm, _expr.UIntImm)):
raise ValueError("So far only constant shape access supported!")
self._check_id_a_buffer(node.value.value.id)
return self._args[node.value.value.id].shape[args.value]
buf = self._get_buffer_from_id(node.value.value.id)
return buf.shape[args.value]
else:
raise ValueError("Not supported yet!")

Expand Down Expand Up @@ -303,8 +322,30 @@ def visit_Call(self, node):
elif func_id in MATH_INTRIN:
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id == 'allocate':
#TODO: Support it later!
return make_nop()
if not isinstance(node.args[0], ast.Tuple):
raise ValueError("allocate's first argument should be a tuple of shape!")
shape = tuple(self.visit(i) for i in node.args[0].elts)
for i in shape:
if not isinstance(i, _expr.Expr):
raise ValueError("The shape should be an expression")
if n > 1:
if not isinstance(node.args[1], ast.Str):
raise ValueError("The data type should be an string")
dtype = node.args[1].s
else:
dtype = 'float32'
if n > 2:
if not isinstance(node.args[2], ast.Str):
raise ValueError("The data type should be an string")
scope = node.args[2].s
else:
scope = 'global'
return (shape, dtype, scope)
elif func_id == 'max' or func_id == 'min':
if n != 2:
raise ValueError("Max/Min function should have 2 elements")
a, b = self.visit(node.args[0]), self.visit(node.args[1])
return getattr(_make, func_id.title())(a, b)
else:
raise ValueError("Function call not supported yet!")

Expand All @@ -317,8 +358,10 @@ def visit_For(self, node):
if iter_var is None:
if for_type is None:
raise ValueError("The loop bind function parse error!")
iter_var = _api.var(_name)
self.loops_above[_name] = iter_var
offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
offset = iter_var + low
self.loops_above[_name] = offset
else:
if for_type is not None:
raise ValueError("The loop iterating function parse error!")
Expand All @@ -328,7 +371,7 @@ def visit_For(self, node):
if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
else:
res = _make.For(iter_var, low, ext, for_type, 0, _body)
res = _make.For(iter_var, _api.const(0, dtype='int32'), ext, for_type, 0, _body)
self.loops_above.pop(_name)
return res

Expand Down
10 changes: 0 additions & 10 deletions python/tvm/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@ def make_nop():
return _make.Evaluate(_api.const(0, dtype='int32'))


def make_range_one():
"""Returns a [0, 1] range node in HalideIR."""
return _make.range_by_min_extent(0, 1)


def make_const_true():
"""Returns a constant True node in HalideIR."""
return _api.convert(True)


def _pruned_source(func):
"""Prune source code's extra leading spaces"""
lines = inspect.getsource(func).split('\n')
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/hybrid/var_decl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def visit_Call(self, node):
#No function pointer supported so far
if not isinstance(node.func, ast.Name):
raise ValueError("Function call should be an id")
if (node.func.id not in HYBRID_GLOBALS.keys()) and node.func.id != 'range':
func_id = node.func.id
if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']:
raise ValueError("Function call id not in intrinsics' list")
for elem in node.args:
self.visit(elem)
Expand All @@ -64,7 +65,6 @@ def visit_Name(self, node):
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
loop = self.scope_level[-1]
usage.add(type(node.ctx))
self.status[node.id] = (decl, loop, usage)

Expand Down
Loading

0 comments on commit f721a64

Please sign in to comment.