From 8f3fbb831c10bd0c5729fa4fa3565b7f6a2efcd9 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 3 Jan 2020 22:19:00 -0800 Subject: [PATCH] [Relay][Pass]Improve memory_allocation pass to support multiple i/o dynamic kernels (#4595) * Add more shape funcs * Fix test * Enhance test_any_concat * Fix pylint * Minor fix test * Fix pylint * Minor refactor * Add test any for elemwise --- python/tvm/relay/memory_alloc.py | 38 +++++++++------ python/tvm/relay/op/_tensor.py | 21 ++++++--- python/tvm/relay/op/_transform.py | 78 +++++++++++++++++++++++-------- tests/python/relay/test_any.py | 42 ++++++++++++++++- 4 files changed, 136 insertions(+), 43 deletions(-) diff --git a/python/tvm/relay/memory_alloc.py b/python/tvm/relay/memory_alloc.py index 9de5431fa0aa..a8d1a30ade10 100644 --- a/python/tvm/relay/memory_alloc.py +++ b/python/tvm/relay/memory_alloc.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return,invalid-name,len-as-condition +# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks """ A pass for manifesting explicit memory allocations. """ @@ -173,6 +173,8 @@ def visit_call(self, call): new_args = [self.visit(arg) for arg in call.args] ins = expr.Tuple(new_args) ret_type = call.checked_type + view = LinearizeRetType(ret_type) + out_types = view.unpack() is_dynamic = ret_type.is_dynamic() # TODO(@jroesch): restore this code, more complex then it seems @@ -180,26 +182,37 @@ def visit_call(self, call): # is_dynamic = is_dynamic or arg.checked_type.is_dynamic() if is_dynamic: - assert isinstance(ret_type, ty.TensorType) shape_func_ins = [] engine = compile_engine.get() cfunc = engine.lower_shape_func(call.op, self.target_host) input_states = cfunc.shape_func_param_states is_inputs = [] + input_pos = 0 for i, (arg, state) in enumerate(zip(new_args, input_states)): state = int(state) # Pass Shapes if state == 2: - sh_of = self.visit(self.shape_of(arg)) - shape_func_ins.append( - scope.let("in_shape_{0}".format(i), sh_of)) + if isinstance(arg.type_annotation, ty.TupleType): + for j in range(len(arg.type_annotation.fields)): + let_in_arg = scope.let("in_arg_{0}".format(input_pos + j), + expr.TupleGetItem(arg, j)) + sh_of = self.visit(self.shape_of(let_in_arg)) + shape_func_ins.append( + scope.let("in_shape_{0}".format(input_pos + j), sh_of)) + input_pos += len(arg.type_annotation.fields) + else: + sh_of = self.visit(self.shape_of(arg)) + shape_func_ins.append( + scope.let("in_shape_{0}".format(input_pos), sh_of)) + input_pos += 1 is_inputs.append(0) # Pass Inputs elif state == 1: new_arg = self.visit(arg) shape_func_ins.append( - scope.let("in_shape_{0}".format(i), new_arg)) + scope.let("in_shape_{0}".format(input_pos), new_arg)) + input_pos += 1 is_inputs.append(1) # TODO(@jroesch): handle 3rd case else: @@ -219,9 +232,6 @@ def visit_call(self, call): scope.let("shape_func", shape_call) - out_types = [] - out_types.append(call.checked_type) - storages = [] for out_shape, out_type in zip(out_shapes, out_types): size = self.compute_storage_in_relay( @@ -242,15 +252,13 @@ def visit_call(self, call): alloc = scope.let("out_{i}".format(i=i), alloc) outs.append(alloc) - invoke = self.invoke_tvm(call.op, ins, expr.Tuple(outs)) + tuple_outs = expr.Tuple(outs) + invoke = self.invoke_tvm(call.op, ins, tuple_outs) scope.let("", invoke) - return outs[0] + return outs[0] if len(outs) == 1 else tuple_outs else: - view = 
LinearizeRetType(ret_type) - out_tys = view.unpack() - outs = [] - for i, out_ty in enumerate(out_tys): + for i, out_ty in enumerate(out_types): out = self.make_static_allocation(scope, out_ty, i) outs.append(out) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 114ff2a14f70..b4a3697ad8f1 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -18,9 +18,11 @@ """Backend compiler related feature registration""" from __future__ import absolute_import import topi +from topi.util import get_const_tuple from .op import register_compute, register_schedule, register_pattern, register_shape_func from .op import schedule_injective, OpPattern from ...hybrid import script +from ...api import convert schedule_broadcast = schedule_injective schedule_elemwise = schedule_injective @@ -120,20 +122,20 @@ def _cast_shape_function(x): def cast_shape_func(attrs, inputs, out_ndims): return [_cast_shape_function(*inputs)] -# shape func @script -def _full_shape_func(x): - out_ndim = len(x) +def _full_shape_func(shape): + out_ndim = len(shape) out = output_tensor((out_ndim,), "int64") for i in const_range(out_ndim): - out[i] = x[i] + out[i] = int64(shape[i]) return out def full_shape_func(attrs, inputs, out_ndims): """ Shape func for zeros, zeros_like, ones, ones_like. """ - return [_full_shape_func(*inputs)] + shape = get_const_tuple(attrs.shape) + return [_full_shape_func(convert(shape))] @script def _broadcast_shape_func(x, y, ndim): @@ -177,9 +179,11 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("cast", False, cast_shape_func) register_shape_func("zeros", False, full_shape_func) -register_shape_func("zeros_like", False, full_shape_func) +register_shape_func("zeros_like", False, elemwise_shape_func) register_shape_func("ones", False, full_shape_func) -register_shape_func("ones_like", False, full_shape_func) +register_shape_func("ones_like", False, elemwise_shape_func) +register_shape_func("full", False, full_shape_func) +register_shape_func("full_like", False, elemwise_shape_func) register_shape_func("add", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func) @@ -196,6 +200,9 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("less_equal", False, broadcast_shape_func) register_shape_func("greater", False, broadcast_shape_func) register_shape_func("greater_equal", False, broadcast_shape_func) +register_shape_func("maximum", False, broadcast_shape_func) +register_shape_func("minimum", False, broadcast_shape_func) register_shape_func("sqrt", False, elemwise_shape_func) register_shape_func("negative", False, elemwise_shape_func) +register_shape_func("exp", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index de708fb09bf8..9f32c250ce6f 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -452,24 +452,8 @@ def transpose_shape_func(attrs, inputs, _): @script def _squeeze_shape_func(data_shape, keep_axes): out = output_tensor((len(keep_axes),), "int64") - if len(keep_axes) == 0: - out_size = 0 - for i in const_range(data_shape.shape[0]): - if data_shape[i] != 1: - out_size += 1 - - if out_size == 0: - out_size = 1 - out = output_tensor((out_size,), "int64") - out[0] = int64(1) - pos = 0 - for i in const_range(data_shape.shape[0]): - if data_shape[i] != 1: - out[pos] = data_shape[i] - pos += 1 - else: - for i in const_range(len(keep_axes)): - out[i] = data_shape[keep_axes[i]] + for i in 
const_range(len(keep_axes)): + out[i] = data_shape[keep_axes[i]] return out @@ -485,7 +469,16 @@ def squeeze_shape_func(attrs, inputs, _): if i not in axis: keep_axes.append(i) - return [_squeeze_shape_func(inputs[0], convert(keep_axes))] + # Due to current relay type system, it is possible even + # a static kernel function needs shape function. To handle + # this case, we allow axis to be None in squeeze shape func + # for now. + # TODO(kevinthesun): Enhance relay type system to avoid this. + if keep_axes: + out = _squeeze_shape_func(inputs[0], convert(keep_axes)) + else: + out = tvm.compute((), lambda *indices: 0) + return [out] @script def _reshape_like_shape_func(target_shape): @@ -527,9 +520,56 @@ def _tile_shape_func(data, reps, ndim, tndim, rndim): @_reg.register_shape_func("tile", False) def tile_shape_func(attrs, inputs, _): + """ + Shape function for tile op. + """ reps = get_const_tuple(attrs.reps) ndim = inputs[0].shape[0].value rndim = len(reps) tndim = ndim if ndim > rndim else rndim return [_tile_shape_func(inputs[0], convert(reps), convert(ndim), convert(tndim), convert(rndim))] + +@script +def _split_shape_func(data_shape, index, indices_or_sections, axis): + out = output_tensor((data_shape.shape[0],), "int64") + if len(indices_or_sections) == 1: + for i in const_range(data_shape.shape[0]): + if i == axis: + out[i] = ceil_div(data_shape[axis], indices_or_sections[0]) + else: + out[i] = data_shape[i] + else: + start = int64(0) + if index > 0: + start = int64(indices_or_sections[index - 1]) + end = data_shape[axis] + if index < len(indices_or_sections): + end = int64(indices_or_sections[index]) + for i in const_range(data_shape.shape[0]): + if i == axis: + out[i] = end - start + else: + out[i] = data_shape[i] + return out + +@_reg.register_shape_func("split", False) +def split_shape_func(attrs, inputs, _): + """ + Shape function for split op. 
+ """ + if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)): + indices_or_sections = get_const_int(attrs.indices_or_sections) + else: + indices_or_sections = get_const_tuple(attrs.indices_or_sections) + + axis = get_const_int(attrs.axis) + + num_out = indices_or_sections if isinstance(indices_or_sections, int) \ + else len(indices_or_sections) + 1 + if isinstance(indices_or_sections, int): + indices_or_sections = [indices_or_sections] + return [_split_shape_func(inputs[0], + convert(i), + convert(indices_or_sections), + convert(axis)) for i in range(num_out)] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index d7246da89226..a30326c8615b 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -59,6 +59,22 @@ def test_any_broadcast(): verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add) verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add) +def verify_any_elemwise(x_shape, x_np_shape, op, np_op): + dtype = 'float32' + x = relay.var('x', shape=x_shape, dtype=dtype) + mod = relay.module.Module() + mod["main"] = relay.Function([x], op(x)) + x_np = np.random.uniform(size=x_np_shape).astype(dtype) + res_np = np_op(x_np) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np) + tvm.testing.assert_allclose(result.asnumpy(), res_np) + +def test_any_elemwise(): + verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt) + verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative) + verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp) def test_any_broadcast_fail(): # Test broadcast with incompatible values at runtime @@ -107,12 +123,14 @@ def test_any_full(): def test_any_concat(): x = relay.var('x', shape=(relay.Any(), 2), dtype="float32") y = relay.var('y', shape=(1, 2), dtype="float32") - z = relay.op.concatenate([x, y], axis=0) + xx = x - relay.expr.const(3.0) + yy = y * relay.expr.const(5.0) + z = relay.op.concatenate([xx, yy], axis=0) mod = relay.module.Module() mod["main"] = relay.Function([x, y], z) x_np = np.random.uniform(size=(3, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32') - ref = np.concatenate([x_np, y_np], axis=0) + ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(x_np, y_np) @@ -417,6 +435,24 @@ def test_any_global_pool2d(): verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any(), 4), "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 1, 1, 4)) +def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, ref_out_shape): + mod = relay.Module() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.split(data, indices_or_sections, axis) + mod["main"] = relay.Function([data], y.astuple()) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + for kind in ["vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + for ret, ref_ret in zip(result, ref_out_shape): + assert ret.asnumpy().shape == ref_ret, \ + "Shape mismatch: expect %s but got %s." 
% (str(ref_ret), str(ret.asnumpy().shape)) + +def test_any_split(): + verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)]) + verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)]) + def test_any_batch_flatten(): mod = relay.Module() dtype = "float32" @@ -601,11 +637,13 @@ def _body(i, st): if __name__ == "__main__": test_any_full() test_any_broadcast() + test_any_elemwise() test_any_broadcast_fail() test_any_concat() test_any_reshape() test_any_take() test_any_tile() + test_any_split() test_any_shape_of() test_any_reduce() test_any_layout_transform()
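
A minimal end-to-end sketch of the capability this patch adds, mirroring the new tests in
test_any.py (it assumes the same era of TVM API used there: relay.Any, relay.module.Module,
and relay.create_executor with the "vm" kind; the shapes and variable names are illustrative
only). A concatenate over inputs whose leading dimension is only known at runtime feeds a
split with multiple outputs, so both the multi-input and the multi-output dynamic paths of
the memory allocation pass and the new shape functions are exercised:

    import numpy as np
    import tvm
    from tvm import relay

    # Two inputs with a runtime-determined leading dimension.
    x = relay.var("x", shape=(relay.Any(), 2), dtype="float32")
    y = relay.var("y", shape=(relay.Any(), 2), dtype="float32")
    # Multiple dynamic inputs (concatenate) feeding a multi-output op (split).
    z = relay.split(relay.op.concatenate([x, y], axis=0),
                    indices_or_sections=2, axis=1)
    mod = relay.module.Module()
    mod["main"] = relay.Function([x, y], z.astuple())

    x_np = np.random.uniform(size=(3, 2)).astype("float32")
    y_np = np.random.uniform(size=(1, 2)).astype("float32")
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    lhs, rhs = ex.evaluate()(x_np, y_np)
    print(lhs.asnumpy().shape, rhs.asnumpy().shape)  # expected: (4, 1) (4, 1)

Before this change, the dynamic branch of visit_call in memory_alloc.py asserted a single
TensorType result and returned only outs[0], and tuple arguments were not unpacked when
collecting input shapes, so a program combining dynamic multi-input and multi-output kernels
like the one above could not be lowered.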