Commit
[Relay][Pass] Improve memory_allocation pass to support multiple i/o dynamic kernels (apache#4595)

* Add more shape funcs

* Fix test

* Enhance test_any_concat

* Fix pylint

* Minor fix test

* Fix pylint

* Minor refactor

* Add test any for elemwise
kevinthesun authored and alexwong committed Feb 28, 2020
1 parent 700ca8a commit 8f3fbb8
Showing 4 changed files with 136 additions and 43 deletions.
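
Editor's note: for context, here is a minimal sketch of the capability this commit adds, distilled from the new test_any_split in tests/python/relay/test_any.py below. It is an illustration only, not part of the diff: a multi-output op (split) applied to an input whose leading dimension is unknown until runtime.

    import numpy as np
    import tvm
    from tvm import relay

    # A dynamic-shape input: the first dimension is only known at runtime.
    data = relay.var("data", shape=(relay.Any(), 4), dtype="float32")
    mod = relay.Module()
    mod["main"] = relay.Function([data], relay.split(data, 2, axis=1).astuple())

    # The VM invokes the new split shape function at runtime to size both
    # outputs before calling the kernel.
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    data_np = np.random.uniform(size=(9, 4)).astype("float32")
    for out, ref_shape in zip(ex.evaluate()(data_np), [(9, 2), (9, 2)]):
        assert out.asnumpy().shape == ref_shape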
38 changes: 23 additions & 15 deletions python/tvm/relay/memory_alloc.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=no-else-return,invalid-name,len-as-condition
+# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
 """
 A pass for manifesting explicit memory allocations.
 """
@@ -173,33 +173,46 @@ def visit_call(self, call):
             new_args = [self.visit(arg) for arg in call.args]
             ins = expr.Tuple(new_args)
             ret_type = call.checked_type
+            view = LinearizeRetType(ret_type)
+            out_types = view.unpack()

             is_dynamic = ret_type.is_dynamic()
             # TODO(@jroesch): restore this code, more complex than it seems
             # for arg in call.args:
             #     is_dynamic = is_dynamic or arg.checked_type.is_dynamic()

             if is_dynamic:
-                assert isinstance(ret_type, ty.TensorType)
                 shape_func_ins = []
                 engine = compile_engine.get()
                 cfunc = engine.lower_shape_func(call.op, self.target_host)
                 input_states = cfunc.shape_func_param_states

                 is_inputs = []
+                input_pos = 0
                 for i, (arg, state) in enumerate(zip(new_args, input_states)):
                     state = int(state)
                     # Pass Shapes
                     if state == 2:
-                        sh_of = self.visit(self.shape_of(arg))
-                        shape_func_ins.append(
-                            scope.let("in_shape_{0}".format(i), sh_of))
+                        if isinstance(arg.type_annotation, ty.TupleType):
+                            for j in range(len(arg.type_annotation.fields)):
+                                let_in_arg = scope.let("in_arg_{0}".format(input_pos + j),
+                                                       expr.TupleGetItem(arg, j))
+                                sh_of = self.visit(self.shape_of(let_in_arg))
+                                shape_func_ins.append(
+                                    scope.let("in_shape_{0}".format(input_pos + j), sh_of))
+                            input_pos += len(arg.type_annotation.fields)
+                        else:
+                            sh_of = self.visit(self.shape_of(arg))
+                            shape_func_ins.append(
+                                scope.let("in_shape_{0}".format(input_pos), sh_of))
+                            input_pos += 1
                         is_inputs.append(0)
                     # Pass Inputs
                     elif state == 1:
                         new_arg = self.visit(arg)
                         shape_func_ins.append(
-                            scope.let("in_shape_{0}".format(i), new_arg))
+                            scope.let("in_shape_{0}".format(input_pos), new_arg))
+                        input_pos += 1
                         is_inputs.append(1)
                     # TODO(@jroesch): handle 3rd case
                     else:
@@ -219,9 +232,6 @@ def visit_call(self, call):

                 scope.let("shape_func", shape_call)

-                out_types = []
-                out_types.append(call.checked_type)
-
                 storages = []
                 for out_shape, out_type in zip(out_shapes, out_types):
                     size = self.compute_storage_in_relay(
@@ -242,15 +252,13 @@ def visit_call(self, call):
                     alloc = scope.let("out_{i}".format(i=i), alloc)
                     outs.append(alloc)

-                invoke = self.invoke_tvm(call.op, ins, expr.Tuple(outs))
+                tuple_outs = expr.Tuple(outs)
+                invoke = self.invoke_tvm(call.op, ins, tuple_outs)
                 scope.let("", invoke)
-                return outs[0]
+                return outs[0] if len(outs) == 1 else tuple_outs
             else:
-                view = LinearizeRetType(ret_type)
-                out_tys = view.unpack()
-
                 outs = []
-                for i, out_ty in enumerate(out_tys):
+                for i, out_ty in enumerate(out_types):
                     out = self.make_static_allocation(scope, out_ty, i)
                     outs.append(out)

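Editor's note on the hunk above: the loop dispatches on cfunc.shape_func_param_states, which records per argument whether the shape function consumes the argument's shape (state 2) or the argument itself (state 1). Below is a plain-Python sketch of that assembly, illustration only, with hypothetical names (gather_shape_func_inputs, shape_of); the two-state convention and the flattening of tuple arguments are taken from the diff.

    def gather_shape_func_inputs(args, param_states, shape_of):
        """Mirror of the dispatch above: state 2 passes shapes, state 1 passes inputs."""
        shape_func_ins, is_inputs = [], []
        input_pos = 0  # flattened slot index, as in the diff's "in_shape_{pos}" lets
        for arg, state in zip(args, param_states):
            if state == 2:
                # A tuple argument (e.g. concatenate's field list) is flattened
                # into one shape per field, keeping slot numbering consecutive.
                fields = arg if isinstance(arg, tuple) else (arg,)
                for field in fields:
                    shape_func_ins.append(shape_of(field))
                    input_pos += 1
                is_inputs.append(0)
            elif state == 1:
                shape_func_ins.append(arg)
                input_pos += 1
                is_inputs.append(1)
            else:
                raise ValueError("unsupported shape function input state")
        return shape_func_ins, is_inputs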
21 changes: 14 additions & 7 deletions python/tvm/relay/op/_tensor.py
@@ -18,9 +18,11 @@
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 import topi
+from topi.util import get_const_tuple
 from .op import register_compute, register_schedule, register_pattern, register_shape_func
 from .op import schedule_injective, OpPattern
 from ...hybrid import script
+from ...api import convert

 schedule_broadcast = schedule_injective
 schedule_elemwise = schedule_injective
@@ -120,20 +122,20 @@ def _cast_shape_function(x):
 def cast_shape_func(attrs, inputs, out_ndims):
     return [_cast_shape_function(*inputs)]

-# shape func
 @script
-def _full_shape_func(x):
-    out_ndim = len(x)
+def _full_shape_func(shape):
+    out_ndim = len(shape)
     out = output_tensor((out_ndim,), "int64")
     for i in const_range(out_ndim):
-        out[i] = x[i]
+        out[i] = int64(shape[i])
     return out

 def full_shape_func(attrs, inputs, out_ndims):
     """
     Shape func for zeros, zeros_like, ones, ones_like.
     """
-    return [_full_shape_func(*inputs)]
+    shape = get_const_tuple(attrs.shape)
+    return [_full_shape_func(convert(shape))]

 @script
 def _broadcast_shape_func(x, y, ndim):
@@ -177,9 +179,11 @@ def elemwise_shape_func(attrs, inputs, _):

 register_shape_func("cast", False, cast_shape_func)
 register_shape_func("zeros", False, full_shape_func)
-register_shape_func("zeros_like", False, full_shape_func)
+register_shape_func("zeros_like", False, elemwise_shape_func)
 register_shape_func("ones", False, full_shape_func)
-register_shape_func("ones_like", False, full_shape_func)
+register_shape_func("ones_like", False, elemwise_shape_func)
+register_shape_func("full", False, full_shape_func)
+register_shape_func("full_like", False, elemwise_shape_func)

 register_shape_func("add", False, broadcast_shape_func)
 register_shape_func("subtract", False, broadcast_shape_func)
@@ -196,6 +200,9 @@ def elemwise_shape_func(attrs, inputs, _):
 register_shape_func("less_equal", False, broadcast_shape_func)
 register_shape_func("greater", False, broadcast_shape_func)
 register_shape_func("greater_equal", False, broadcast_shape_func)
+register_shape_func("maximum", False, broadcast_shape_func)
+register_shape_func("minimum", False, broadcast_shape_func)
+
 register_shape_func("sqrt", False, elemwise_shape_func)
 register_shape_func("negative", False, elemwise_shape_func)
 register_shape_func("exp", False, elemwise_shape_func)
78 changes: 59 additions & 19 deletions python/tvm/relay/op/_transform.py
@@ -452,24 +452,8 @@ def transpose_shape_func(attrs, inputs, _):
 @script
 def _squeeze_shape_func(data_shape, keep_axes):
     out = output_tensor((len(keep_axes),), "int64")
-    if len(keep_axes) == 0:
-        out_size = 0
-        for i in const_range(data_shape.shape[0]):
-            if data_shape[i] != 1:
-                out_size += 1
-
-        if out_size == 0:
-            out_size = 1
-        out = output_tensor((out_size,), "int64")
-        out[0] = int64(1)
-        pos = 0
-        for i in const_range(data_shape.shape[0]):
-            if data_shape[i] != 1:
-                out[pos] = data_shape[i]
-                pos += 1
-    else:
-        for i in const_range(len(keep_axes)):
-            out[i] = data_shape[keep_axes[i]]
+    for i in const_range(len(keep_axes)):
+        out[i] = data_shape[keep_axes[i]]

     return out

@@ -485,7 +469,16 @@ def squeeze_shape_func(attrs, inputs, _):
         if i not in axis:
             keep_axes.append(i)

-    return [_squeeze_shape_func(inputs[0], convert(keep_axes))]
+    # Due to current relay type system, it is possible even
+    # a static kernel function needs shape function. To handle
+    # this case, we allow axis to be None in squeeze shape func
+    # for now.
+    # TODO(kevinthesun): Enhance relay type system to avoid this.
+    if keep_axes:
+        out = _squeeze_shape_func(inputs[0], convert(keep_axes))
+    else:
+        out = tvm.compute((), lambda *indices: 0)
+    return [out]

 @script
 def _reshape_like_shape_func(target_shape):
@@ -527,9 +520,56 @@ def _tile_shape_func(data, reps, ndim, tndim, rndim):

 @_reg.register_shape_func("tile", False)
 def tile_shape_func(attrs, inputs, _):
+    """
+    Shape function for tile op.
+    """
     reps = get_const_tuple(attrs.reps)
     ndim = inputs[0].shape[0].value
     rndim = len(reps)
     tndim = ndim if ndim > rndim else rndim
     return [_tile_shape_func(inputs[0], convert(reps), convert(ndim),
                              convert(tndim), convert(rndim))]
+
+@script
+def _split_shape_func(data_shape, index, indices_or_sections, axis):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    if len(indices_or_sections) == 1:
+        for i in const_range(data_shape.shape[0]):
+            if i == axis:
+                out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
+            else:
+                out[i] = data_shape[i]
+    else:
+        start = int64(0)
+        if index > 0:
+            start = int64(indices_or_sections[index - 1])
+        end = data_shape[axis]
+        if index < len(indices_or_sections):
+            end = int64(indices_or_sections[index])
+        for i in const_range(data_shape.shape[0]):
+            if i == axis:
+                out[i] = end - start
+            else:
+                out[i] = data_shape[i]
+    return out
+
+@_reg.register_shape_func("split", False)
+def split_shape_func(attrs, inputs, _):
+    """
+    Shape function for split op.
+    """
+    if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)):
+        indices_or_sections = get_const_int(attrs.indices_or_sections)
+    else:
+        indices_or_sections = get_const_tuple(attrs.indices_or_sections)
+
+    axis = get_const_int(attrs.axis)
+
+    num_out = indices_or_sections if isinstance(indices_or_sections, int) \
+        else len(indices_or_sections) + 1
+    if isinstance(indices_or_sections, int):
+        indices_or_sections = [indices_or_sections]
+    return [_split_shape_func(inputs[0],
+                              convert(i),
+                              convert(indices_or_sections),
+                              convert(axis)) for i in range(num_out)]
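
Editor's note on _split_shape_func above, a plain-Python mirror to make the two branches concrete: an integer argument means "this many equal sections" (sized with ceil_div on the split axis), while a tuple of indices means piece k spans [indices[k-1], indices[k]) on that axis. Illustration only; it reproduces the shapes checked by the new test_any_split below.

    def split_shapes_ref(data_shape, indices_or_sections, axis):
        if isinstance(indices_or_sections, int):
            n = indices_or_sections
            piece = -(-data_shape[axis] // n)  # ceil_div, as in the hybrid script
            return [tuple(piece if i == axis else d
                          for i, d in enumerate(data_shape))] * n
        bounds = [0] + list(indices_or_sections) + [data_shape[axis]]
        return [tuple(bounds[k + 1] - bounds[k] if i == axis else d
                      for i, d in enumerate(data_shape))
                for k in range(len(bounds) - 1)]

    assert split_shapes_ref((9, 4), 2, 1) == [(9, 2), (9, 2)]
    assert split_shapes_ref((7, 12), (1, 4, 8), 1) == [(7, 1), (7, 3), (7, 4), (7, 4)]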
42 changes: 40 additions & 2 deletions tests/python/relay/test_any.py
@@ -59,6 +59,22 @@ def test_any_broadcast():
     verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
     verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)

+def verify_any_elemwise(x_shape, x_np_shape, op, np_op):
+    dtype = 'float32'
+    x = relay.var('x', shape=x_shape, dtype=dtype)
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], op(x))
+    x_np = np.random.uniform(size=x_np_shape).astype(dtype)
+    res_np = np_op(x_np)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(x_np)
+        tvm.testing.assert_allclose(result.asnumpy(), res_np)
+
+def test_any_elemwise():
+    verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt)
+    verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative)
+    verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp)

 def test_any_broadcast_fail():
     # Test broadcast with incompatible values at runtime
@@ -107,12 +123,14 @@ def test_any_full():
 def test_any_concat():
     x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
     y = relay.var('y', shape=(1, 2), dtype="float32")
-    z = relay.op.concatenate([x, y], axis=0)
+    xx = x - relay.expr.const(3.0)
+    yy = y * relay.expr.const(5.0)
+    z = relay.op.concatenate([xx, yy], axis=0)
     mod = relay.module.Module()
     mod["main"] = relay.Function([x, y], z)
     x_np = np.random.uniform(size=(3, 2)).astype('float32')
     y_np = np.random.uniform(size=(1, 2)).astype('float32')
-    ref = np.concatenate([x_np, y_np], axis=0)
+    ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
     for kind in ["debug", "vm"]:
         ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
         result = ex.evaluate()(x_np, y_np)
@@ -417,6 +435,24 @@ def test_any_global_pool2d():
     verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4),
                              "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4))

+def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, ref_out_shape):
+    mod = relay.Module()
+    dtype = "float32"
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.split(data, indices_or_sections, axis)
+    mod["main"] = relay.Function([data], y.astuple())
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        for ret, ref_ret in zip(result, ref_out_shape):
+            assert ret.asnumpy().shape == ref_ret, \
+                "Shape mismatch: expect %s but got %s." % (str(ref_ret), str(ret.asnumpy().shape))
+
+def test_any_split():
+    verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)])
+    verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
+
 def test_any_batch_flatten():
     mod = relay.Module()
     dtype = "float32"
@@ -601,11 +637,13 @@ def _body(i, st):
 if __name__ == "__main__":
     test_any_full()
     test_any_broadcast()
+    test_any_elemwise()
     test_any_broadcast_fail()
     test_any_concat()
     test_any_reshape()
     test_any_take()
     test_any_tile()
+    test_any_split()
     test_any_shape_of()
     test_any_reduce()
     test_any_layout_transform()
