From da75d85cdce6fa189f3662793e0a68e0f84309f1 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 6 Aug 2020 08:46:58 -0700 Subject: [PATCH 1/5] [Relay][Dynamic] OneHot operation (#6209) * Dynamic OneHot Op * refactor dynamic_to_static * add onehot to dynamic_to_static pass --- include/tvm/topi/transform.h | 19 +-- python/tvm/relay/op/dyn/_transform.py | 35 +++++- python/tvm/relay/op/transform.py | 15 ++- src/relay/op/dyn/tensor/transform.cc | 70 +++++++++++ src/relay/op/make_op.h | 2 + src/relay/transforms/dynamic_to_static.cc | 113 ++++++++++++------ .../relay/dyn/test_dynamic_op_level10.py | 64 ++++++++-- .../relay/test_pass_dynamic_to_static.py | 28 +++++ 8 files changed, 285 insertions(+), 61 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index cd19436d4103..19b2ef4aad74 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1421,22 +1421,25 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype, * \param depth depth of the one-hot dimension. * \param axis axis to fill. * \param dtype data type of the output tensor. + * \param oshape shape of the output tensor. * \param name output tensor name. * \param tag output tensor tag. * \return one-hot tensor. */ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, int depth, int axis, const DataType& dtype, + Array oshape = Array(), const std::string name = "T_one_hot", const std::string tag = kInjective) { - Array oshape; - int ndim = indices->shape.size() + 1; - int indices_index = 0; int true_axis = (axis == -1) ? indices->shape.size() : axis; - for (int i = 0; i < ndim; i++) { - if (i == true_axis) { - oshape.push_back(Integer(depth)); - } else { - oshape.push_back(indices->shape[indices_index++]); + if (oshape.size() == 0) { + int ndim = indices->shape.size() + 1; + int indices_index = 0; + for (int i = 0; i < ndim; i++) { + if (i == true_axis) { + oshape.push_back(Integer(depth)); + } else { + oshape.push_back(indices->shape[indices_index++]); + } } } diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index e2704bc24e62..3a80f5aec957 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -25,11 +25,13 @@ _reg.register_broadcast_schedule("dyn.broadcast_to") _reg.register_injective_schedule("dyn.reshape") _reg.register_broadcast_schedule("dyn.tile") +_reg.register_injective_schedule("dyn.one_hot") + @script def _reshape_shape_func_input_data(data, newshape, ndim): - out = output_tensor((ndim,), "int64") - data_shape = allocate((len(data.shape),), "int64") + out = output_tensor((ndim, ), "int64") + data_shape = allocate((len(data.shape), ), "int64") for x in const_range(len(data.shape)): data_shape[x] = int64(data.shape[x]) src_idx = 0 @@ -59,7 +61,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim): elif newshape[i] == -3: assert data_shape.shape[0] - src_idx > 1, \ "Not enough dims in input shape for -3" - out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1] + out[dst_idx] = data_shape[src_idx] * data_shape[src_idx + 1] src_idx += 2 dst_idx += 1 elif newshape[i] == -4: @@ -82,6 +84,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim): out[infer_idx] = old_size // new_size return out + @_reg.register_shape_func("dyn.reshape", True) def dynamic_reshape_shape_func(attrs, inputs, out_ndims): return [_reshape_shape_func_input_data(*inputs, out_ndims[0])] @@ -89,7 +92,7 @@ def 
dynamic_reshape_shape_func(attrs, inputs, out_ndims): @script def _tile_shape_func(data, reps, ndim, tndim, rndim): - out = output_tensor((tndim,), "int64") + out = output_tensor((tndim, ), "int64") if ndim == rndim: for i in const_range(tndim): @@ -120,5 +123,25 @@ def tile_shape_func(attrs, inputs, _): ndim = len(inputs[0].shape) rndim = inputs[1].shape[0].value tndim = ndim if ndim > rndim else rndim - return [_tile_shape_func(inputs[0], reps, convert(ndim), - convert(tndim), convert(rndim))] + return [_tile_shape_func(inputs[0], reps, convert(ndim), convert(tndim), convert(rndim))] + + +@script +def _onehot_shape_func(dshape, k, axis): + ndim = len(dshape) + 1 + out = output_tensor((ndim, ), "int64") + for i in const_range(axis): + out[i] = int64(dshape[i]) + out[axis] = int64(k[0]) + for j in const_range(axis + 1, ndim): + out[j] = int64(dshape[j - 1]) + return out + + +@_reg.register_shape_func("dyn.one_hot", True) +def one_hot_shape_func(attrs, inputs, _): + """ + Shape function for dyn.one_hot op. + """ + axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis + return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 6f23af2c0cee..5e5b86717be2 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -148,6 +148,7 @@ def squeeze(data, axis=None): """ return _make.squeeze(data, axis) + def reshape(data, newshape): """Reshape the input array. @@ -228,6 +229,7 @@ def reshape(data, newshape): newshape = tempshape return _make.reshape(data, list(newshape)) + def argwhere(condition): """Find the indices of elements of a tensor that are non-zero. @@ -251,6 +253,7 @@ def argwhere(condition): """ return _make.argwhere(condition) + def scatter(data, indices, updates, axis): """Update data at positions defined by indices with values in updates @@ -275,6 +278,7 @@ def scatter(data, indices, updates, axis): """ return _make.scatter(data, indices, updates, axis) + def scatter_add(data, indices, updates, axis): """Update data by adding values in updates at positions defined by indices @@ -299,6 +303,7 @@ def scatter_add(data, indices, updates, axis): """ return _make.scatter_add(data, indices, updates, axis) + def reshape_like(data, shape_like): """Reshapes the input array by the size of another array. For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes @@ -442,6 +447,7 @@ def arange(start, stop=None, step=None, dtype="float32"): return _make.arange(start, stop, step, dtype) + def meshgrid(data, indexing="ij"): """Create coordinate matrices from coordinate vectors. @@ -482,6 +488,7 @@ def meshgrid(data, indexing="ij"): ret_size = len(data) return TupleWrapper(_make.meshgrid(Tuple(data), indexing), ret_size) + def repeat(data, repeats, axis): """Repeats elements of an array. By default, repeat flattens the input array into 1-D and then repeats the elements. @@ -668,6 +675,7 @@ def where(condition, x, y): """ return _make.where(condition, x, y) + def broadcast_to(data, shape): """Return a scalar value array with the same type, broadcast to the provided shape. @@ -693,6 +701,7 @@ def broadcast_to(data, shape): shape = list(shape) return _make.broadcast_to(data, shape) + def broadcast_to_like(data, broadcast_type): """Return a scalar value array with the same shape and type as the input array. 
@@ -1053,6 +1062,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): """ return _make.sequence_mask(data, valid_length, mask_value, axis) + def one_hot(indices, on_value, off_value, depth, axis, dtype): """ Returns a one-hot tensor where the locations repsented by indices take value on_value, @@ -1070,7 +1080,7 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): off_value : relay.Expr Value to fill at all other positions besides indices. - depth : int + depth : int or relay.Expr Depth of the one-hot dimension. axis : int @@ -1095,6 +1105,8 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): [0, 1, 0], [0, 0, 1]] """ + if isinstance(depth, Expr): + return _dyn_make.one_hot(indices, on_value, off_value, depth, axis, dtype) return _make.one_hot(indices, on_value, off_value, depth, axis, dtype) @@ -1120,6 +1132,7 @@ def unravel_index(indices, shape): return _make.unravel_index(indices, shape) + def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0): """Converts a sparse representation into a dense tensor. diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index 2bb87acd7dce..d2d6d69382e0 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -304,6 +304,76 @@ RELAY_REGISTER_OP("dyn.ones") .set_support_level(3) .add_type_rel("DynamicInitOp", InitOpRel); +bool OneHotRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [indices, on_value, off_value, result] + CHECK_EQ(types.size(), 5); + const auto* indices = types[0].as(); + CHECK(indices); + + const auto param = attrs.as(); + + Array oshape; + int ndim = indices->shape.size() + 1; + int indices_index = 0; + int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis; + for (int i = 0; i < ndim; i++) { + if (i == true_axis) { + oshape.push_back(Any()); + } else { + oshape.push_back(indices->shape[indices_index++]); + } + } + + reporter->Assign(types[4], TensorType(oshape, param->dtype)); + return true; +} + +Array OneHotCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + const auto* out_ttype = out_type.as(); + return Array{topi::one_hot(inputs[0], inputs[1](), inputs[2](), -1, param->axis, + param->dtype, out_ttype->shape)}; +} + +Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, Expr depth, int axis, DataType dtype) { + auto attrs = make_object(); + attrs->axis = axis; + attrs->dtype = dtype; + static const Op& op = Op::Get("dyn.one_hot"); + return Call(op, {indices, on_value, off_value, depth}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn._make.one_hot").set_body_typed(MakeOneHot); + +RELAY_REGISTER_OP("dyn.one_hot") + .describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1, + other locations take value 0. Final dimension is x depth. + + **indices** Locations to set to 1. + + **on_value** Value to fill at indices. + + **off_value** Value to fill at all other positions besides indices. + + **depth** Depth of the one-hot dimension. + + **axis** Axis to fill. 
+ + **dtype**)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("indices", "Tensor", "Locations to set to on_value.") + .add_argument("on_value", "Expr", "Value to fill at indices.") + .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") + .add_argument("depth", "Expr", "Value to fill at all other positions besides indices.") + .set_support_level(10) + .add_type_rel("DynOneHot", OneHotRel) + .set_attr("FTVMCompute", OneHotCompute) + .set_attr("TOpPattern", kOutEWiseFusable); + } // namespace dyn } // namespace relay } // namespace tvm diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 3b5e9a195957..d2c170d718ae 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -78,6 +78,8 @@ Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool Expr MakeZeros(Array shape, DataType dtype); +Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_MAKE_OP_H_ diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index d4de15c6ee5a..8501ee590e0b 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -33,44 +33,82 @@ namespace relay { class DynamicToStaticMutator : public MixedModeMutator { public: - DynamicToStaticMutator() {} + DynamicToStaticMutator() { + op_map_ = { + {Op::Get("dyn.reshape"), + [](const CallNode* call_node) { + if (const ConstantNode* shape = call_node->args[1].as()) { + CHECK_EQ(shape->data->ndim, 1); + return MakeReshape(call_node->args[0], ToVector(shape->data)); + } + return Expr(nullptr); + }}, + {Op::Get("dyn.tile"), + [](const CallNode* call_node) { + if (const ConstantNode* reps = call_node->args[1].as()) { + CHECK_EQ(reps->data->ndim, 1); + return MakeTile(call_node->args[0], ToVector(reps->data)); + } + return Expr(nullptr); + }}, + {Op::Get("dyn.topk"), + [](const CallNode* call_node) { + if (const ConstantNode* k = call_node->args[1].as()) { + const TopKAttrs* param = call_node->attrs.as(); + CHECK(param); + return MakeTopK(call_node->args[0], static_cast(ToScalar(k->data, 0)), + param->axis, param->ret_type, param->is_ascend, param->dtype); + } + return Expr(nullptr); + }}, + {Op::Get("dyn.broadcast_to"), + [](const CallNode* call_node) { + if (const ConstantNode* shape = call_node->args[1].as()) { + CHECK_EQ(shape->data->ndim, 1); + return MakeBroadCastTo(call_node->args[0], ToVector(shape->data)); + } + return Expr(nullptr); + }}, + {Op::Get("dyn.zeros"), + [](const CallNode* call_node) { + if (const ConstantNode* shape = call_node->args[0].as()) { + const InitOpAttrs* param = call_node->attrs.as(); + CHECK(param); + return MakeZeros(ToVector(shape->data), param->dtype); + } + return Expr(nullptr); + }}, + {Op::Get("dyn.ones"), + [](const CallNode* call_node) { + if (const ConstantNode* shape = call_node->args[0].as()) { + const InitOpAttrs* param = call_node->attrs.as(); + CHECK(param); + return MakeOnes(ToVector(shape->data), param->dtype); + } + return Expr(nullptr); + }}, + {Op::Get("dyn.one_hot"), + [](const CallNode* call_node) { + if (const ConstantNode* depth = call_node->args[3].as()) { + const OneHotAttrs* param = call_node->attrs.as(); + CHECK(param); + return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2], + static_cast(ToScalar(depth->data, 0)), param->axis, + param->dtype); + } + return Expr(nullptr); + }}, + }; + } 
private: Expr Rewrite_(const CallNode* pre, const Expr& post) override { - const CallNode* call_node = post.as(); - if (call_node->op == Op::Get("dyn.reshape")) { - if (const ConstantNode* shape = call_node->args[1].as()) { - CHECK_EQ(shape->data->ndim, 1); - return MakeReshape(call_node->args[0], ToVector(shape->data)); - } - } else if (call_node->op == Op::Get("dyn.tile")) { - if (const ConstantNode* reps = call_node->args[1].as()) { - CHECK_EQ(reps->data->ndim, 1); - return MakeTile(call_node->args[0], ToVector(reps->data)); - } - } else if (call_node->op == Op::Get("dyn.topk")) { - if (const ConstantNode* k = call_node->args[1].as()) { - const TopKAttrs* param = call_node->attrs.as(); - CHECK(param); - return MakeTopK(call_node->args[0], static_cast(ToScalar(k->data, 0)), param->axis, - param->ret_type, param->is_ascend, param->dtype); - } - } else if (call_node->op == Op::Get("dyn.broadcast_to")) { - if (const ConstantNode* shape = call_node->args[1].as()) { - CHECK_EQ(shape->data->ndim, 1); - return MakeBroadCastTo(call_node->args[0], ToVector(shape->data)); - } - } else if (call_node->op == Op::Get("dyn.zeros")) { - if (const ConstantNode* shape = call_node->args[0].as()) { - const InitOpAttrs* param = call_node->attrs.as(); - CHECK(param); - return MakeZeros(ToVector(shape->data), param->dtype); - } - } else if (call_node->op == Op::Get("dyn.ones")) { - if (const ConstantNode* shape = call_node->args[0].as()) { - const InitOpAttrs* param = call_node->attrs.as(); - CHECK(param); - return MakeOnes(ToVector(shape->data), param->dtype); + if (const CallNode* call_node = post.as()) { + if (op_map_.count(call_node->op)) { + auto out = op_map_[call_node->op](call_node); + if (out.defined()) { + return out; + } } } return post; @@ -83,6 +121,8 @@ class DynamicToStaticMutator : public MixedModeMutator { } return post; } + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + op_map_; }; Expr DynamicToStatic(Function f, IRModule m) { @@ -90,6 +130,7 @@ Expr DynamicToStatic(Function f, IRModule m) { Expr expr = f; auto fold_const = transform::FoldConstant(); auto infer_type = transform::InferType(); + DynamicToStaticMutator mutator; Map vars; for (auto kv : m->functions) { vars.Set(kv.second, kv.first); @@ -101,7 +142,7 @@ Expr DynamicToStatic(Function f, IRModule m) { // TODO(mbrookhart): Is it possible to run these passes JUST on the current function? 
m = infer_type(m); m = fold_const(m); - expr = DynamicToStaticMutator().Mutate(m->functions[gv]); + expr = mutator.Mutate(m->functions[gv]); m->Update(gv, Downcast(expr)); i += 1; } while (pre != expr && i < 1000); diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index d9b23a717f65..95a030f5b8ae 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -19,36 +19,80 @@ """ - import numpy as np import tvm from tvm import relay from tvm.relay.testing import ctx_list, run_infer_type +import tvm.topi.testing import random + def test_dyn_broadcast_to(): dtype = 'uint8' rank = 3 shape_type = 'int64' - dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type)) - x_shape = (1,) + dyn_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type)) + x_shape = (1, ) x = relay.Var("x", relay.ty.TensorType(x_shape, dtype)) z = relay.broadcast_to(x, dyn_shape) zz = run_infer_type(z) - - assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype) + + assert zz.checked_type == relay.ty.TensorType((relay.Any(), ) * rank, dtype) func = relay.Function([x, dyn_shape], z) - + x = np.random.uniform(size=x_shape).astype(dtype) - dyn_shape = (1,)*rank + dyn_shape = (1, ) * rank ref_res = np.broadcast_to(x, dyn_shape) for target, ctx in ctx_list(): - if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU + if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x,np.array(dyn_shape).astype(shape_type)) + op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) -test_dyn_broadcast_to() + +def test_dyn_one_hot(): + def _get_oshape(indices_shape, depth, axis): + oshape = [] + true_axis = len(indices_shape) if axis == -1 else axis + ndim = len(indices_shape) + 1 + indices_index = 0 + for i in range(0, ndim): + if i == true_axis: + oshape.append(depth) + else: + oshape.append(indices_shape[indices_index]) + indices_index += 1 + + return oshape + + def _verify(indices_shape, depth, on_value, off_value, axis, dtype): + indices = relay.var("indices", relay.TensorType(indices_shape, "int32")) + depth_var = relay.var("depth", relay.TensorType((), "int32")) + on_value_const = relay.const(on_value) + off_value_const = relay.const(off_value) + out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype) + func = relay.Function([indices, depth_var], out) + indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32") + out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype) + for target, ctx in ctx_list(): + if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32")) + tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) + + _verify((3, ), 3, 1, 0, -1, "int32") + _verify((3, ), 3, 1.0, 0.0, -1, "float32") + _verify((2, 2), 5, 2, -2, 0, "int32") + _verify((2, 2), 5, 0.5, -0.5, 1, "float32") + _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") + _verify((3, 2, 4, 5), 6, 
1.0, 0.0, 0, "float32") + + +if __name__ == "__main__": + test_dyn_broadcast_to() + test_dyn_one_hot() diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 8ca788212ff3..a50c9df8cf49 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -22,6 +22,8 @@ from tvm.relay.build_module import bind_params_by_name from tvm.relay.testing import run_infer_type, create_workload, ctx_list +import tvm.topi.testing + def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) @@ -222,6 +224,32 @@ def verify_ones_zeros(shape, dtype): verify_ones_zeros((1, 2, 3), 'int64') verify_ones_zeros((9, 8, 3, 4), 'float32') +def test_dynamic_to_static_one_hot(): + def _verify(indices_shape, depth, on_value, off_value, axis, dtype): + indices = relay.var("indices", relay.TensorType(indices_shape, "int32")) + depth_var = relay.const(depth) + on_value_const = relay.const(on_value) + off_value_const = relay.const(off_value) + out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype) + func = relay.Function([indices], out) + + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("one_hot") + + indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32") + out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype) + verify_func(func2, [indices_np], out_np) + + _verify((3, ), 3, 1, 0, -1, "int32") + _verify((3, ), 3, 1.0, 0.0, -1, "float32") + _verify((2, 2), 5, 2, -2, 0, "int32") + _verify((2, 2), 5, 0.5, -0.5, 1, "float32") + _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") + _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") + if __name__=="__main__": test_dynamic_to_static_reshape() test_dynamic_to_static_double_reshape() From b485d478f280938cebf5d3072a4420c2cef56c6e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 6 Aug 2020 16:08:19 -0700 Subject: [PATCH 2/5] =?UTF-8?q?[=C2=B5TVM]=20Add=20--runtime=3Dc,=20remove?= =?UTF-8?q?=20micro=5Fdev=20target,=20enable=20LLVM=20backend=20(#6145)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * need to fill address of globals in tvmfuncregistry * llvm func registry generator works! * lint fixes * rm hexdump include * bring bundle_deploy back to life and add to CI * revert gcda additions * git-clang-format * fix check for --system-lib and test_runtime_micro target * fixup compile flags for bundle_deploy CRT and improve robustness * git-clang-format * add debugging info * git-clang-format * initialize ret_values in PackedFunc_Call. * retrigger CI * fix log messages * git-clang-format * remove default for --runtime target opt * put backtrace behind a flag and enable it * simpify ReadString(), fixing bad instruction exception on os x. 
* git-clang-format * uncomment tests * reorder backtrace ldflags for linux gcc --- apps/bundle_deploy/Makefile | 141 +++++++++++------- apps/bundle_deploy/backtrace.c | 57 +++++++ apps/bundle_deploy/backtrace.h | 30 ++++ apps/bundle_deploy/build_model.py | 69 +++++---- apps/bundle_deploy/bundle.c | 67 +++++---- apps/bundle_deploy/bundle.cc | 1 + apps/bundle_deploy/bundle.h | 2 +- apps/bundle_deploy/bundle_static.c | 12 +- apps/bundle_deploy/demo.cc | 74 ++++++++- apps/bundle_deploy/demo_static.c | 12 +- apps/bundle_deploy/runtime.cc | 37 +++++ apps/bundle_deploy/test.cc | 81 +++++++--- apps/bundle_deploy/test_static.c | 2 +- include/tvm/target/target_kind.h | 6 + src/runtime/crt/Makefile | 4 +- src/runtime/crt/common/crt_runtime_api.c | 2 - src/runtime/crt/common/memory.c | 2 - src/runtime/crt/common/packed_func.c | 2 + src/runtime/crt/graph_runtime/graph_runtime.c | 58 +++++-- src/runtime/crt/graph_runtime/load_json.c | 52 +++---- .../crt/internal/graph_runtime/load_json.h | 4 +- src/support/str_escape.h | 12 +- src/target/func_registry_generator.cc | 43 ++++++ src/target/func_registry_generator.h | 38 +++++ src/target/llvm/codegen_amdgpu.cc | 2 +- src/target/llvm/codegen_cpu.cc | 71 ++++++++- src/target/llvm/codegen_cpu.h | 13 +- src/target/llvm/codegen_llvm.cc | 5 +- src/target/llvm/codegen_llvm.h | 6 +- src/target/llvm/codegen_nvptx.cc | 2 +- src/target/llvm/llvm_module.cc | 15 +- src/target/source/codegen_c_host.cc | 46 +++++- src/target/source/codegen_c_host.h | 11 ++ src/target/target_kind.cc | 11 +- tests/python/unittest/test_runtime_micro.py | 2 +- .../unittest/test_target_codegen_llvm.py | 11 ++ tests/scripts/task_python_integration.sh | 2 +- 37 files changed, 761 insertions(+), 244 deletions(-) create mode 100644 apps/bundle_deploy/backtrace.c create mode 100644 apps/bundle_deploy/backtrace.h create mode 100644 apps/bundle_deploy/runtime.cc create mode 100644 src/target/func_registry_generator.cc create mode 100644 src/target/func_registry_generator.h diff --git a/apps/bundle_deploy/Makefile b/apps/bundle_deploy/Makefile index eeea539d16d2..adb8d3386bdf 100644 --- a/apps/bundle_deploy/Makefile +++ b/apps/bundle_deploy/Makefile @@ -21,13 +21,16 @@ TVM_ROOT=$(shell cd ../..; pwd) CRT_ROOT ?= ../../src/runtime/crt +ENABLE_TVM_PLATFORM_ABORT_BACKTRACE ?= 1 + DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CXXFLAGS = -g -Wall -std=c++14 -O2 -fPIC \ +PKG_COMPILE_OPTS = -g -Wall -O2 -fPIC +PKG_CXXFLAGS = ${PKG_COMPILE_OPTS} -std=c++14 \ -I${TVM_ROOT}/include \ -I${DMLC_CORE}/include \ -I${TVM_ROOT}/3rdparty/dlpack/include \ -Icrt_config -PKG_CFLAGS = -g -Wall -std=c99 -O2 -fPIC \ +PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ -I${TVM_ROOT}/include \ -I${DMLC_CORE}/include \ -I${TVM_ROOT}/3rdparty/dlpack/include \ @@ -37,90 +40,116 @@ PKG_LDFLAGS = -pthread build_dir := build +BACKTRACE_SRCS = +BACKTRACE_LDFLAGS = +BACKTRACE_CFLAGS = +$(ifeq ENABLE_TVM_PLATFORM_ABORT_BACKTRACE,1) +BACKTRACE_SRCS += backtrace.c +BACKTRACE_LDFLAGS += -ldl +BACKTRACE_CFLAGS += -DENABLE_TVM_PLATFORM_ABORT_BACKTRACE +$(endif) + +BACKTRACE_OBJS = $(patsubst %.c,$(build_dir)/%.o,$(BACKTRACE_SRCS)) + +$(ifeq VERBOSE,1) +QUIET ?= +$(else) +QUIET ?= @ +$(endif) -demo_dynamic: $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/cat.bin - TVM_NUM_THREADS=1 $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/cat.bin - TVM_NUM_THREADS=1 $(build_dir)/demo_dynamic $(build_dir)/bundle_c.so $(build_dir)/cat.bin -test_dynamic: $(build_dir)/test_dynamic $(build_dir)/test_bundle.so 
$(build_dir)/test_bundle_c.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin - TVM_NUM_THREADS=1 $(build_dir)/test_dynamic $(build_dir)/test_bundle.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin $(build_dir)/test_graph.json $(build_dir)/test_params.bin - TVM_NUM_THREADS=1 $(build_dir)/test_dynamic $(build_dir)/test_bundle_c.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin $(build_dir)/test_graph.json $(build_dir)/test_params.bin +demo_dynamic: $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/bundle.so $(build_dir)/graph_cpp.json $(build_dir)/graph_c.json $(build_dir)/params_cpp.bin $(build_dir)/params_c.bin $(build_dir)/cat.bin + $(QUIET)TVM_NUM_THREADS=1 $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/graph_cpp.json $(build_dir)/params_cpp.bin $(build_dir)/cat.bin + $(QUIET)TVM_NUM_THREADS=1 $(build_dir)/demo_dynamic $(build_dir)/bundle_c.so $(build_dir)/graph_c.json $(build_dir)/params_c.bin $(build_dir)/cat.bin + +test_dynamic: $(build_dir)/test_dynamic $(build_dir)/test_bundle.so $(build_dir)/test_bundle_c.so $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(build_dir)/test_data_cpp.bin $(build_dir)/test_output_cpp.bin + $(QUIET)TVM_NUM_THREADS=1 $(build_dir)/test_dynamic $(build_dir)/test_bundle.so $(build_dir)/test_data_cpp.bin $(build_dir)/test_output_cpp.bin $(build_dir)/test_graph_cpp.json $(build_dir)/test_params_cpp.bin + $(QUIET)TVM_NUM_THREADS=1 $(build_dir)/test_dynamic $(build_dir)/test_bundle_c.so $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(build_dir)/test_graph_c.json $(build_dir)/test_params_c.bin demo_static: $(build_dir)/demo_static $(build_dir)/cat.bin - TVM_NUM_THREADS=1 $(build_dir)/demo_static $(build_dir)/cat.bin + $(QUIET)TVM_NUM_THREADS=1 $(build_dir)/demo_static $(build_dir)/cat.bin -test_static: $(build_dir)/test_static $(build_dir)/test_data.bin $(build_dir)/test_output.bin - TVM_NUM_THREADS=1 $(build_dir)/test_static $(build_dir)/test_data.bin $(build_dir)/test_output.bin $(build_dir)/test_graph.json $(build_dir)/test_params.bin +test_static: $(build_dir)/test_static $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin + $(QUIET)TVM_NUM_THREADS=1 $(build_dir)/test_static $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(build_dir)/test_graph_c.json $(build_dir)/test_params_c.bin $(build_dir)/crt/graph_runtime/libgraph_runtime.a: - cd $(CRT_ROOT) && make QUIET= BUILD_DIR=$(abspath $(build_dir))/crt CRT_CONFIG=$(abspath crt_config/crt_config.h) graph_runtime + $(QUIET)cd $(CRT_ROOT) && make QUIET= BUILD_DIR=$(abspath $(build_dir))/crt CRT_CONFIG=$(abspath crt_config/crt_config.h) "EXTRA_CFLAGS=$(PKG_COMPILE_OPTS)" graph_runtime $(build_dir)/crt/common/libcommon.a: - cd $(CRT_ROOT) && make QUIET= BUILD_DIR=$(abspath $(build_dir))/crt CRT_CONFIG=$(abspath crt_config/crt_config.h) common + $(QUIET)cd $(CRT_ROOT) && make QUIET= BUILD_DIR=$(abspath $(build_dir))/crt CRT_CONFIG=$(abspath crt_config/crt_config.h) "EXTRA_CFLAGS=$(PKG_COMPILE_OPTS)" common -$(build_dir)/demo_dynamic: demo.cc ${build_dir}/graph.json.c ${build_dir}/params.bin.c - @mkdir -p $(@D) - g++ $(PKG_CXXFLAGS) -o $@ demo.cc -ldl +$(build_dir)/demo_dynamic: demo.cc + $(QUIET)mkdir -p $(@D) + $(QUIET)g++ $(PKG_CXXFLAGS) -o $@ demo.cc $(BACKTRACE_LDFLAGS) -$(build_dir)/test_dynamic: test.cc ${build_dir}/test_graph.json ${build_dir}/test_params.bin - @mkdir -p $(@D) - g++ $(PKG_CXXFLAGS) -o $@ test.cc -ldl +$(build_dir)/test_dynamic: test.cc ${build_dir}/test_graph_c.json 
${build_dir}/test_params_c.bin $(BACKTRACE_OBJS) + $(QUIET)mkdir -p $(@D) + $(QUIET)g++ $(PKG_CXXFLAGS) -o $@ test.cc $(BACKTRACE_OBJS) $(BACKTRACE_LDFLAGS) -$(build_dir)/model.o: $(build_dir)/model.c - gcc $(PKG_CFLAGS) -c -o $@ $^ +$(build_dir)/demo_static: demo_static.c ${build_dir}/bundle_static.o ${build_dir}/model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS) + $(QUIET)mkdir -p $(@D) + $(QUIET)gcc $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) -$(build_dir)/demo_static: demo_static.c ${build_dir}/bundle_static.o ${build_dir}/func_registry.c ${build_dir}/model.o ${build_dir}/graph.json.c ${build_dir}/params.bin.c ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a - @mkdir -p $(@D) - gcc $(PKG_CFLAGS) -o $@ demo_static.c ${build_dir}/bundle_static.o ${build_dir}/func_registry.c ${build_dir}/model.o -lm ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a +$(build_dir)/test_static: test_static.c ${build_dir}/bundle_static.o ${build_dir}/test_model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS) + $(QUIET)mkdir -p $(@D) + $(QUIET)gcc $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_LDFLAGS) -$(build_dir)/test_static: test_static.c ${build_dir}/bundle_static.o ${build_dir}/test_func_registry.c ${build_dir}/test_model.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a - @mkdir -p $(@D) - gcc $(PKG_CFLAGS) -o $@ $^ +$(build_dir)/backtrace.o: backtrace.c + $(QUIET)mkdir -p $(@D) + $(QUIET)gcc -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) # Serialize our graph.json file. -$(build_dir)/graph.json.c: $(build_dir)/graph.json - xxd -i $^ > $@ +$(build_dir)/graph_cpp.json.c: $(build_dir)/graph_cpp.json + $(QUIET)xxd -i $^ > $@ -# Serialize our params.bin file. -$(build_dir)/params.bin.c: $(build_dir)/params.bin - xxd -i $^ > $@ +$(build_dir)/graph_c.json.c: $(build_dir)/graph_c.json + $(QUIET)xxd -i $^ > $@ -$(build_dir)/func_registry.c $(build_dir)/model.c $(build_dir)/graph.json $(build_dir)/params.bin $(build_dir)/cat.bin: build_model.py - python3 $< -o $(build_dir) +# Serialize our params.bin file. 
+$(build_dir)/params_c.bin.c: $(build_dir)/params_c.bin + $(QUIET)xxd -i $^ > $@ -$(build_dir)/test_func_registry.c $(build_dir)/test_model.c $(build_dir)/test_graph.json $(build_dir)/test_params.bin $(build_dir)/test_data.bin $(build_dir)/test_output.bin: build_model.py - python3 $< -o $(build_dir) --test +$(build_dir)/params_cpp.bin.c: $(build_dir)/params_cpp.bin + $(QUIET)xxd -i $^ > $@ -$(build_dir)/test_model.o: $(build_dir)/test_model.c - gcc $(PKG_CFLAGS) -c -o $@ $^ +$(build_dir)/model_c.o $(build_dir)/graph_c.json $(build_dir)/model_cpp.o $(build_dir)/graph_cpp.json $(build_dir)/params.bin $(build_dir)/cat.bin: build_model.py + $(QUIET)python3 $< -o $(build_dir) -$(build_dir)/func_registry.o: $(build_dir)/func_registry.c - gcc $(PKG_CFLAGS) -c -o $@ $^ +$(build_dir)/test_model_c.o $(build_dir)/test_graph_c.json $(build_dir)/test_params_c.bin $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(build_dir)/test_model_cpp.o $(build_dir)/test_graph_cpp.json $(build_dir)/test_params_cpp.bin $(build_dir)/test_data_cpp.bin $(build_dir)/test_output_cpp.bin: build_model.py + $(QUIET)python3 $< -o $(build_dir) --test # Build our bundle against the serialized bundle.c API, the runtime.cc API, and # the serialized graph.json and params.bin -$(build_dir)/bundle.so: bundle.cc $(build_dir)/model.o $(build_dir)/func_registry.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a - @mkdir -p $(@D) - g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) +$(build_dir)/bundle.so: bundle.cc runtime.cc $(build_dir)/model_cpp.o + $(QUIET)mkdir -p $(@D) + $(QUIET)g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) -$(build_dir)/bundle_c.so: bundle.c runtime.c $(build_dir)/model.o $(build_dir)/func_registry.c - @mkdir -p $(@D) - gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) +$(build_dir)/bundle_c.so: bundle.c $(build_dir)/model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS) + $(QUIET)mkdir -p $(@D) + $(QUIET)gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -$(build_dir)/test_bundle.so: bundle.cc runtime.cc $(build_dir)/test_model.o $(build_dir)/test_func_registry.c - @mkdir -p $(@D) - g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) +$(build_dir)/test_bundle.so: bundle.cc runtime.cc $(build_dir)/test_model_cpp.o + $(QUIET)mkdir -p $(@D) + $(QUIET)g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) -$(build_dir)/test_bundle_c.so: bundle.c runtime.c $(build_dir)/test_model.o $(build_dir)/test_func_registry.c - @mkdir -p $(@D) - gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) +$(build_dir)/test_bundle_c.so: bundle.c $(build_dir)/test_model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS) + $(QUIET)mkdir -p $(@D) + $(QUIET)gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) $(build_dir)/bundle_static.o: bundle_static.c - @mkdir -p $(@D) - gcc -c $(PKG_CFLAGS) -o $@ $^ + $(QUIET)mkdir -p $(@D) + $(QUIET)gcc -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) clean: - rm -rf $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/test_bundle.so $(build_dir)/test_bundle_c.so $(build_dir)/crt + $(QUIET)rm -rf $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/test_bundle.so 
$(build_dir)/test_bundle_c.so $(build_dir)/crt cleanall: - rm -rf $(build_dir) + $(QUIET)rm -rf $(build_dir) + +# Don't define implicit rules; they tend to match on logical target names that aren't targets (i.e. bundle_static) +.SUFFIXES: + +.DEFAULT: demo_static demo_dynamic + +test: test_static test_dynamic +.PHONY: test diff --git a/apps/bundle_deploy/backtrace.c b/apps/bundle_deploy/backtrace.c new file mode 100644 index 000000000000..adb9b94959e2 --- /dev/null +++ b/apps/bundle_deploy/backtrace.c @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#define _GNU_SOURCE +#include "backtrace.h" + +#include +#include +#include +#include +#include + +const char* g_argv0 = NULL; + +void tvm_platform_abort_backtrace() { + void* trace[200]; + int nptrs = backtrace(trace, sizeof(trace) / sizeof(void*)); + fprintf(stderr, "backtrace: %d\n", nptrs); + if (nptrs < 0) { + perror("backtracing"); + } else { + backtrace_symbols_fd(trace, nptrs, STDOUT_FILENO); + + char cmd_buf[1024]; + for (int i = 0; i < nptrs; i++) { + Dl_info info; + if (dladdr(trace[i], &info)) { + fprintf(stderr, "symbol %d: %s %s %p (%p)\n", i, info.dli_sname, info.dli_fname, + info.dli_fbase, (void*)(trace[i] - info.dli_fbase)); + snprintf(cmd_buf, sizeof(cmd_buf), "addr2line --exe=%s -p -i -a -f %p", g_argv0, + (void*)(trace[i] - info.dli_fbase)); + int result = system(cmd_buf); + if (result < 0) { + perror("invoking backtrace command"); + } + } else { + fprintf(stderr, "symbol %d: %p (unmapped)\n", i, trace[i]); + } + } + } +} diff --git a/apps/bundle_deploy/backtrace.h b/apps/bundle_deploy/backtrace.h new file mode 100644 index 000000000000..1f6346b85703 --- /dev/null +++ b/apps/bundle_deploy/backtrace.h @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifdef __cplusplus +extern "C" { +#endif + +extern const char* g_argv0; + +void tvm_platform_abort_backtrace(void); + +#ifdef __cplusplus +} +#endif diff --git a/apps/bundle_deploy/build_model.py b/apps/bundle_deploy/build_model.py index 2fe8ef34bfc5..e99623fbdcc0 100644 --- a/apps/bundle_deploy/build_model.py +++ b/apps/bundle_deploy/build_model.py @@ -21,10 +21,14 @@ from tvm import relay import tvm from tvm import te -from tvm.micro import func_registry import logging import json +RUNTIMES = { + 'c': '{name}_c.{ext}', + 'c++': '{name}_cpp.{ext}', +} + def build_module(opts): dshape = (1, 3, 224, 224) from mxnet.gluon.model_zoo.vision import get_model @@ -34,21 +38,20 @@ def build_module(opts): func = mod["main"] func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) - with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}): - graph, lib, params = relay.build( - func, 'c', params=params) + for runtime_name, file_format_str in RUNTIMES.items(): + with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}): + graph, lib, params = relay.build( + func, f'llvm --runtime={runtime_name} --system-lib', params=params) - build_dir = os.path.abspath(opts.out_dir) - if not os.path.isdir(build_dir): - os.makedirs(build_dir) + build_dir = os.path.abspath(opts.out_dir) + if not os.path.isdir(build_dir): + os.makedirs(build_dir) - lib.save(os.path.join(build_dir, 'model.c'), 'cc') - with open(os.path.join(build_dir, 'graph.json'), 'w') as f_graph_json: - f_graph_json.write(graph) - with open(os.path.join(build_dir, 'params.bin'), 'wb') as f_params: - f_params.write(relay.save_param_dict(params)) - func_registry.graph_json_to_c_func_registry(os.path.join(build_dir, 'graph.json'), - os.path.join(build_dir, 'func_registry.c')) + lib.save(os.path.join(build_dir, file_format_str.format(name='model', ext='o'))) + with open(os.path.join(build_dir, file_format_str.format(name='graph', ext='json')), 'w') as f_graph_json: + f_graph_json.write(graph) + with open(os.path.join(build_dir, file_format_str.format(name='params', ext='bin')), 'wb') as f_params: + f_params.write(relay.save_param_dict(params)) def build_test_module(opts): import numpy as np @@ -60,26 +63,26 @@ def build_test_module(opts): x_data = np.random.rand(10, 5).astype('float32') y_data = np.random.rand(1, 5).astype('float32') params = {"y": y_data} - with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}): - graph, lib, params = relay.build( - tvm.IRModule.from_expr(func), "c", params=params) - build_dir = os.path.abspath(opts.out_dir) - if not os.path.isdir(build_dir): - os.makedirs(build_dir) - - lib.save(os.path.join(build_dir, 'test_model.c'), 'cc') - with open(os.path.join(build_dir, 'test_graph.json'), 'w') as f_graph_json: - f_graph_json.write(graph) - with open(os.path.join(build_dir, 'test_params.bin'), 'wb') as f_params: - f_params.write(relay.save_param_dict(params)) - with open(os.path.join(build_dir, "test_data.bin"), "wb") as fp: - fp.write(x_data.astype(np.float32).tobytes()) - func_registry.graph_json_to_c_func_registry(os.path.join(build_dir, 'test_graph.json'), - os.path.join(build_dir, 'test_func_registry.c')) - x_output = x_data + y_data - with open(os.path.join(build_dir, "test_output.bin"), "wb") as fp: - fp.write(x_output.astype(np.float32).tobytes()) + for runtime_name, file_format_str in RUNTIMES.items(): + with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}): + 
graph, lib, lowered_params = relay.build( + tvm.IRModule.from_expr(func), f"llvm --runtime={runtime_name} --system-lib", params=params) + + build_dir = os.path.abspath(opts.out_dir) + if not os.path.isdir(build_dir): + os.makedirs(build_dir) + + lib.save(os.path.join(build_dir, file_format_str.format(name='test_model', ext='o'))) + with open(os.path.join(build_dir, file_format_str.format(name='test_graph', ext='json')), 'w') as f_graph_json: + f_graph_json.write(graph) + with open(os.path.join(build_dir, file_format_str.format(name='test_params', ext='bin')), 'wb') as f_params: + f_params.write(relay.save_param_dict(lowered_params)) + with open(os.path.join(build_dir, file_format_str.format(name="test_data", ext="bin")), "wb") as fp: + fp.write(x_data.astype(np.float32).tobytes()) + x_output = x_data + y_data + with open(os.path.join(build_dir, file_format_str.format(name="test_output", ext="bin")), "wb") as fp: + fp.write(x_output.astype(np.float32).tobytes()) def build_inputs(opts): from tvm.contrib import download diff --git a/apps/bundle_deploy/bundle.c b/apps/bundle_deploy/bundle.c index d86c79e0c1bc..9ff67eaf3ed9 100644 --- a/apps/bundle_deploy/bundle.c +++ b/apps/bundle_deploy/bundle.c @@ -20,6 +20,13 @@ #include #include #include +#include +#include +#include + +#ifdef ENABLE_TVM_ABORT_BACKTRACE +#include "backtrace.h" +#endif /*! \brief macro to do C API call */ #define TVM_CCALL(func) \ @@ -32,7 +39,11 @@ } while (0) TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, - const uint64_t params_size) { + const uint64_t params_size, const char* argv0) { +#ifdef ENABLE_TVM_ABORT_BACKTRACE + g_argv0 = argv0; +#endif + int64_t device_type = kDLCPU; int64_t device_id = 0; @@ -45,50 +56,44 @@ TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, ctx.device_id = device_id; // declare pointers - TVMModuleHandle (*SystemLibraryCreate)(); - TVMModuleHandle (*TVMGraphRuntimeCreate)(const char*, const TVMModuleHandle, const TVMContext*); - int (*TVMGraphRuntime_LoadParams)(TVMModuleHandle, const char*, const uint32_t); - - TVM_CCALL(TVMRuntimeInitialize()); + TVM_CCALL(TVMInitializeRuntime()); + TVMPackedFunc pf; + TVMArgs args = TVMArgs_Create(NULL, NULL, 0); + TVM_CCALL(TVMPackedFunc_InitGlobalFunc(&pf, "runtime.SystemLib", &args)); + TVM_CCALL(TVMPackedFunc_Call(&pf)); - // get pointers - TVM_CCALL(TVMFuncGetGlobal("runtime.SystemLib", (TVMFunctionHandle*)&SystemLibraryCreate)); - TVM_CCALL( - TVMFuncGetGlobal("tvm.graph_runtime.create", (TVMFunctionHandle*)&TVMGraphRuntimeCreate)); + TVMModuleHandle mod_syslib = TVMArgs_AsModuleHandle(&pf.ret_value, 0); // run modules - TVMModuleHandle mod_syslib = SystemLibraryCreate(); - TVMModuleHandle mod = TVMGraphRuntimeCreate(json_data, mod_syslib, &ctx); - TVM_CCALL( - TVMModGetFunction(mod, "load_params", 0, (TVMFunctionHandle*)&TVMGraphRuntime_LoadParams)); - TVMGraphRuntime_LoadParams(mod, params.data, params.size); + TVMGraphRuntime* graph_runtime = TVMGraphRuntime_Create(json_data, mod_syslib, &ctx); + TVMGraphRuntime_LoadParams(graph_runtime, params.data, params.size); - return mod; + return graph_runtime; } TVM_DLL void tvm_runtime_destroy(void* runtime) { - void (*TVMGraphRuntimeRelease)(TVMModuleHandle*); - TVM_CCALL( - TVMFuncGetGlobal("tvm.graph_runtime.release", (TVMFunctionHandle*)&TVMGraphRuntimeRelease)); - TVMGraphRuntimeRelease(&runtime); + TVMGraphRuntime_Release((TVMGraphRuntime**)&runtime); } TVM_DLL void tvm_runtime_set_input(void* runtime, const char* name, DLTensor* 
tensor) { - void (*TVMGraphRuntime_SetInput)(TVMModuleHandle, const char*, DLTensor*); - TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.set_input", - (TVMFunctionHandle*)&TVMGraphRuntime_SetInput)); - TVMGraphRuntime_SetInput(runtime, name, tensor); + TVMGraphRuntime* graph_runtime = (TVMGraphRuntime*)runtime; + TVMGraphRuntime_SetInput(graph_runtime, name, tensor); } TVM_DLL void tvm_runtime_run(void* runtime) { - void (*TVMGraphRuntime_Run)(TVMModuleHandle runtime); - TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.run", (TVMFunctionHandle*)&TVMGraphRuntime_Run)); - TVMGraphRuntime_Run(runtime); + TVMGraphRuntime* graph_runtime = (TVMGraphRuntime*)runtime; + TVMGraphRuntime_Run(graph_runtime); } TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tensor) { - int (*TVMGraphRuntime_GetOutput)(TVMModuleHandle, const int32_t, DLTensor*); - TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.get_output", - (TVMFunctionHandle*)&TVMGraphRuntime_GetOutput)); - TVMGraphRuntime_GetOutput(runtime, index, tensor); + TVMGraphRuntime* graph_runtime = (TVMGraphRuntime*)runtime; + TVMGraphRuntime_GetOutput(graph_runtime, index, tensor); +} + +void __attribute__((noreturn)) TVMPlatformAbort(int error_code) { + fprintf(stderr, "TVMPlatformAbort: %d\n", error_code); +#ifdef ENABLE_TVM_ABORT_BACKTRACE + tvm_platform_abort_backtrace(); +#endif + exit(-1); } diff --git a/apps/bundle_deploy/bundle.cc b/apps/bundle_deploy/bundle.cc index d8ff683decc3..e3cc7d1730ce 100644 --- a/apps/bundle_deploy/bundle.cc +++ b/apps/bundle_deploy/bundle.cc @@ -34,6 +34,7 @@ TVM_BUNDLE_FUNCTION void* tvm_runtime_create(const char* build_graph_json, tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("runtime.SystemLib"))(); int device_type = kDLCPU; int device_id = 0; + tvm::runtime::Module mod = (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))( json_data, mod_syslib, device_type, device_id); TVMByteArray params; diff --git a/apps/bundle_deploy/bundle.h b/apps/bundle_deploy/bundle.h index 80238e1e231a..0d7cab7d5a8d 100644 --- a/apps/bundle_deploy/bundle.h +++ b/apps/bundle_deploy/bundle.h @@ -23,7 +23,7 @@ #include TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, - const uint64_t params_size); + const uint64_t params_size, const char* argv); TVM_DLL void tvm_runtime_destroy(void* runtime); diff --git a/apps/bundle_deploy/bundle_static.c b/apps/bundle_deploy/bundle_static.c index b999a7b63872..6e189b663a16 100644 --- a/apps/bundle_deploy/bundle_static.c +++ b/apps/bundle_deploy/bundle_static.c @@ -22,7 +22,11 @@ #include #include #include +#include +#ifdef ENABLE_TVM_PLATFORM_ABORT_BACKTRACE +#include "backtrace.h" +#endif #include "bundle.h" /*! 
\brief macro to do C API call */ @@ -36,7 +40,10 @@ } while (0) TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, - const uint64_t params_size) { + const uint64_t params_size, const char* argv0) { +#ifdef ENABLE_TVM_PLATFORM_ABORT_BACKTRACE + g_argv0 = argv0; +#endif int64_t device_type = kDLCPU; int64_t device_id = 0; @@ -86,5 +93,8 @@ TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tens void __attribute__((noreturn)) TVMPlatformAbort(int error_code) { fprintf(stderr, "TVMPlatformAbort: %d\n", error_code); +#ifdef ENABLE_TVM_PLATFORM_ABORT_BACKTRACE + tvm_platform_abort_backtrace(); +#endif exit(-1); } diff --git a/apps/bundle_deploy/demo.cc b/apps/bundle_deploy/demo.cc index 5c210a2cab88..34654c6f55c2 100644 --- a/apps/bundle_deploy/demo.cc +++ b/apps/bundle_deploy/demo.cc @@ -26,9 +26,6 @@ #include #include -#include "build/graph.json.c" -#include "build/params.bin.c" - template auto getFunc(void* bundle, const char* name) { dlerror(); @@ -37,14 +34,75 @@ auto getFunc(void* bundle, const char* name) { return f; } +static int read_all(const char* file_description, const char* file_path, char** out_params, + size_t* params_size) { + FILE* fp = fopen(file_path, "rb"); + if (fp == NULL) { + return 2; + } + + int error = 0; + error = fseek(fp, 0, SEEK_END); + if (error < 0) { + return error; + } + + long file_size = ftell(fp); + if (file_size < 0) { + return (int)file_size; + } else if (file_size == 0 || file_size > (10 << 20)) { // file size should be in (0, 20MB]. + char buf[128]; + snprintf(buf, sizeof(buf), "determing file size: %s", file_path); + perror(buf); + return 2; + } + + if (params_size != NULL) { + *params_size = file_size; + } + + error = fseek(fp, 0, SEEK_SET); + if (error < 0) { + return error; + } + + *out_params = (char*)malloc((unsigned long)file_size); + if (fread(*out_params, file_size, 1, fp) != 1) { + free(*out_params); + *out_params = NULL; + + char buf[128]; + snprintf(buf, sizeof(buf), "reading: %s", file_path); + perror(buf); + return 2; + } + + error = fclose(fp); + if (error != 0) { + free(*out_params); + *out_params = NULL; + } + + return 0; +} + int main(int argc, char** argv) { - assert(argc == 3 && "Usage: demo "); + assert(argc == 5 && "Usage: demo "); auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); - char* json_data = reinterpret_cast(build_graph_json); - char* params_data = reinterpret_cast(build_params_bin); - uint64_t params_size = build_params_bin_len; + char* json_data; + int error = read_all("graph.json", argv[2], &json_data, NULL); + if (error != 0) { + return error; + } + + char* params_data; + size_t params_size; + error = read_all("params.bin", argv[3], ¶ms_data, ¶ms_size); + if (error != 0) { + return error; + } struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); @@ -54,7 +112,7 @@ int main(int argc, char** argv) { gettimeofday(&t1, 0); float input_storage[1 * 3 * 224 * 224]; - FILE* fp = fopen(argv[2], "rb"); + FILE* fp = fopen(argv[3], "rb"); fread(input_storage, 3 * 224 * 224, 4, fp); fclose(fp); diff --git a/apps/bundle_deploy/demo_static.c b/apps/bundle_deploy/demo_static.c index 24aafbaf658b..418ab8ef38a6 100644 --- a/apps/bundle_deploy/demo_static.c +++ b/apps/bundle_deploy/demo_static.c @@ -24,8 +24,8 @@ #include #include -#include "build/graph.json.c" -#include "build/params.bin.c" +#include "build/graph_c.json.c" +#include "build/params_c.bin.c" #include "bundle.h" #define OUTPUT_LEN 1000 @@ -33,14 +33,14 @@ int main(int argc, char** argv) { 
assert(argc == 2 && "Usage: demo_static "); - char* json_data = (char*)(build_graph_json); - char* params_data = (char*)(build_params_bin); - uint64_t params_size = build_params_bin_len; + char* json_data = (char*)(build_graph_c_json); + char* params_data = (char*)(build_params_c_bin); + uint64_t params_size = build_params_c_bin_len; struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - void* handle = tvm_runtime_create(json_data, params_data, params_size); + void* handle = tvm_runtime_create(json_data, params_data, params_size, argv[0]); gettimeofday(&t1, 0); float input_storage[1 * 3 * 224 * 224]; diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc new file mode 100644 index 000000000000..8e294a05775d --- /dev/null +++ b/apps/bundle_deploy/runtime.cc @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include "../../src/runtime/c_runtime_api.cc" +#include "../../src/runtime/cpu_device_api.cc" +#include "../../src/runtime/file_util.cc" +#include "../../src/runtime/graph/graph_runtime.cc" +#include "../../src/runtime/library_module.cc" +#include "../../src/runtime/module.cc" +#include "../../src/runtime/ndarray.cc" +#include "../../src/runtime/object.cc" +#include "../../src/runtime/registry.cc" +#include "../../src/runtime/system_library.cc" +#include "../../src/runtime/thread_pool.cc" +#include "../../src/runtime/threading_backend.cc" +#include "../../src/runtime/workspace_pool.cc" diff --git a/apps/bundle_deploy/test.cc b/apps/bundle_deploy/test.cc index 882e04be8ef9..daadd7a57e2b 100644 --- a/apps/bundle_deploy/test.cc +++ b/apps/bundle_deploy/test.cc @@ -35,28 +35,59 @@ auto getFunc(void* bundle, const char* name) { return f; } +char* read_all_or_die(const char* name, const char* file_path, size_t* out_size) { + struct stat st; + if (stat(file_path, &st)) { + char err[1024]; + snprintf(err, sizeof(err), "%s: statting file", name); + perror(err); + abort(); + } + if (st.st_size > 1024 * 1024) { + std::cerr << name << ": file is over 1MB limit: " << st.st_size << " bytes" << std::endl; + abort(); + } + + if (out_size != nullptr) { + *out_size = st.st_size; + } + + char* data = (char*)malloc(st.st_size); + FILE* fp = fopen(file_path, "rb"); + size_t bytes_read = 0; + while (bytes_read < st.st_size) { + size_t this_round = fread(data, 1, st.st_size, fp); + if (this_round == 0) { + if (ferror(fp)) { + char err[1024]; + snprintf(err, sizeof(err), "%s: error during read", name); + perror(err); + } else if (feof(fp)) { + std::cerr << name << ": file is shorter than its stat size (" << bytes_read << " v " + << st.st_size << ")" << std::endl; + } else { + std::cerr << name << ": fread stopped returning data" << std::endl; + } + abort(); + } + 
bytes_read += this_round; + } + + fclose(fp); + return data; +} + int main(int argc, char** argv) { assert(argc == 6 && "Usage: test "); auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); - struct stat st; char* json_data; char* params_data; - uint64_t params_size; - - FILE* fp = fopen(argv[4], "rb"); - stat(argv[4], &st); - json_data = (char*)malloc(st.st_size); - fread(json_data, st.st_size, 1, fp); - fclose(fp); + size_t params_size; - fp = fopen(argv[5], "rb"); - stat(argv[5], &st); - params_data = (char*)malloc(st.st_size); - fread(params_data, st.st_size, 1, fp); - params_size = st.st_size; - fclose(fp); + json_data = read_all_or_die("json_data", argv[4], nullptr); + params_data = read_all_or_die("params_data", argv[5], ¶ms_size); struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); @@ -65,15 +96,19 @@ int main(int argc, char** argv) { json_data, params_data, params_size); gettimeofday(&t1, 0); - float input_storage[10 * 5]; - fp = fopen(argv[2], "rb"); - fread(input_storage, 10 * 5, 4, fp); - fclose(fp); - - float result_storage[10 * 5]; - fp = fopen(argv[3], "rb"); - fread(result_storage, 10 * 5, 4, fp); - fclose(fp); + size_t input_storage_size; + float* input_storage = + reinterpret_cast(read_all_or_die("input_storage", argv[2], &input_storage_size)); + size_t result_storage_size; + float* result_storage = + reinterpret_cast(read_all_or_die("result_storage", argv[3], &result_storage_size)); + + size_t expected_size = 10 * 5 * sizeof(float); + if (input_storage_size != expected_size || result_storage_size != expected_size) { + std::cerr << "wrong input or result storage size (want " << expected_size + << "input_storage_size=" << input_storage_size + << "; result_storage_size=" << result_storage_size << std::endl; + } std::vector input_shape = {10, 5}; DLTensor input; diff --git a/apps/bundle_deploy/test_static.c b/apps/bundle_deploy/test_static.c index fca08d18da74..773ba62140d4 100644 --- a/apps/bundle_deploy/test_static.c +++ b/apps/bundle_deploy/test_static.c @@ -51,7 +51,7 @@ int main(int argc, char** argv) { struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto* handle = tvm_runtime_create(json_data, params_data, params_size); + auto* handle = tvm_runtime_create(json_data, params_data, params_size, argv[0]); gettimeofday(&t1, 0); float input_storage[10 * 5]; diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 7f660be54732..a661efad58f0 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -150,6 +150,12 @@ class TargetKindAttrMap : public AttrRegistryMap { explicit TargetKindAttrMap(const AttrRegistryMapContainerMap& map) : TParent(map) {} }; +/*! \brief Value used with --runtime in target specs to indicate the C++ runtime. */ +static constexpr const char* kTvmRuntimeCpp = "c++"; + +/*! \brief Value used with --runtime in target specs to indicate the C runtime. */ +static constexpr const char* kTvmRuntimeCrt = "c"; + /*! 
* \brief Helper structure to register TargetKind * \sa TVM_REGISTER_TARGET_KIND diff --git a/src/runtime/crt/Makefile b/src/runtime/crt/Makefile index 74c92696a3d9..cf11507d70fa 100644 --- a/src/runtime/crt/Makefile +++ b/src/runtime/crt/Makefile @@ -31,8 +31,8 @@ RANLIB ?= ${PREFIX}ranlib QUIET ?= @ CFLAGS += -isystem "${TVM_INCLUDE_DIR}" -isystem "${DLPACK_INCLUDE_DIR}" -I include -I $(dir ${CRT_CONFIG}) -CFLAGS += -Werror -g -LDFLAGS += -Werror -g +CFLAGS += -Werror -g $(EXTRA_CFLAGS) +LDFLAGS += -Werror -g $(EXTRA_LDFLAGS) ${BUILD_DIR}/%.o: %.c ${QUIET}mkdir -p $(dir $@) diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 12b74db42f2a..c1e994ffb8c7 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -259,9 +259,7 @@ int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_impo int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value, int* ret_type_codes) { - int function_index; TVMModuleHandle mod; - int module_index; const char* name; int to_return; int query_imports; diff --git a/src/runtime/crt/common/memory.c b/src/runtime/crt/common/memory.c index 4ede8ca45783..7a634b9a7033 100644 --- a/src/runtime/crt/common/memory.c +++ b/src/runtime/crt/common/memory.c @@ -278,8 +278,6 @@ void MemoryManagerCreate(MemoryManager* manager, uint8_t* memory_pool, size_t bytes_needed_per_page = page_size_bytes + metadata_bytes_per_page; size_t num_pages = memory_pool_size_bytes / bytes_needed_per_page; - size_t metadata_pages_bytes = ROUND_UP(metadata_bytes_per_page * num_pages, page_size_bytes); - size_t metadata_num_pages = metadata_pages_bytes >> page_size_bytes_log2; uint8_t* metadata_cursor = memory_pool + (num_pages << page_size_bytes_log2); manager->ptable.memory_pool = memory_pool; diff --git a/src/runtime/crt/common/packed_func.c b/src/runtime/crt/common/packed_func.c index 8fc94529b839..b5a3121357d5 100644 --- a/src/runtime/crt/common/packed_func.c +++ b/src/runtime/crt/common/packed_func.c @@ -121,6 +121,8 @@ TVMArgs TVMArgs_Create(TVMValue* values, uint32_t* tcodes, uint32_t values_count } int TVMPackedFunc_Call(TVMPackedFunc* pf) { + pf->ret_value.values_count = 1; + pf->ret_value.tcodes[0] = kTVMNullptr; return TVMFuncCall(pf->fexec, pf->args.values, pf->args.tcodes, pf->args.values_count, pf->ret_value.values, pf->ret_value.tcodes); } diff --git a/src/runtime/crt/graph_runtime/graph_runtime.c b/src/runtime/crt/graph_runtime/graph_runtime.c index cf56a5c8b773..d93e909b8825 100644 --- a/src/runtime/crt/graph_runtime/graph_runtime.c +++ b/src/runtime/crt/graph_runtime/graph_runtime.c @@ -82,8 +82,12 @@ void TVMGraphRuntimeNode_LoadAttrs(TVMGraphRuntimeNode* node, JSONReader* reader memset(key, 0, sizeof(key)); memset(value, 0, sizeof(value)); reader->BeginObject(reader); - while (reader->NextObjectItem(reader, key)) { - reader->ReadString(reader, value); + while (reader->NextObjectItem(reader, key, sizeof(key))) { + int status = reader->ReadString(reader, value, sizeof(value)); + if (status != 0) { + fprintf(stderr, "error reading value for key: %s\n", key); + break; + } if (!strcmp(key, "func_name")) { snprintf(param->func_name, sizeof(value), "%s", value); bitmask |= 1; @@ -110,12 +114,20 @@ int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode* node, JSONReader* reader) { reader->BeginObject(reader); int bitmask = 0; char key[20]; - while (reader->NextObjectItem(reader, key)) { + while (reader->NextObjectItem(reader, key, sizeof(key))) { 
if (!strcmp(key, "op")) { - reader->ReadString(reader, node->op_type); + status = reader->ReadString(reader, node->op_type, sizeof(node->op_type)); + if (status != 0) { + fprintf(stderr, "error reading op\n"); + break; + } bitmask |= 1; } else if (!strcmp(key, "name")) { - reader->ReadString(reader, node->name); + status = reader->ReadString(reader, node->name, sizeof(node->name)); + if (status != 0) { + fprintf(stderr, "error reading name\n"); + break; + } bitmask |= 2; } else if (!strcmp(key, "inputs")) { size_t count = node->inputs_count; @@ -200,7 +212,7 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* re uint32_t shape_count = 0; uint32_t device_index_count = 0; reader->BeginObject(reader); - while (reader->NextObjectItem(reader, key)) { + while (reader->NextObjectItem(reader, key, sizeof(key))) { if (!strcmp(key, "dltype")) { reader->BeginArray(reader); if (!(reader->NextArrayItem(reader))) { @@ -208,7 +220,11 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* re status = -1; break; } - reader->ReadString(reader, type); + status = reader->ReadString(reader, type, sizeof(type)); + if (status != 0) { + fprintf(stderr, "error reading dltype type\n"); + break; + } if (strcmp(type, "list_str")) { fprintf(stderr, "Invalid json format\n"); status = -1; @@ -222,7 +238,12 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* re reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { attr->dltype = vrealloc(attr->dltype, TVM_CRT_STRLEN_DLTYPE * (dltype_count + 1)); - reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_STRLEN_DLTYPE); + status = reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_STRLEN_DLTYPE, + TVM_CRT_STRLEN_DLTYPE); + if (status != 0) { + fprintf(stderr, "error reading dltype array item"); + break; + } dltype_count++; } attr->dltype_count = dltype_count; @@ -240,7 +261,10 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* re status = -1; break; } - reader->ReadString(reader, type); + status = reader->ReadString(reader, type, sizeof(type)); + if (status != 0) { + fprintf(stderr, "error reading device_index array item"); + } if (strcmp(type, "list_int")) { fprintf(stderr, "Invalid json format\n"); status = -1; @@ -270,7 +294,11 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* re status = -1; break; } - reader->ReadString(reader, type); + status = reader->ReadString(reader, type, sizeof(type)); + if (status != 0) { + fprintf(stderr, "error reading shape array item\n"); + break; + } if (strcmp(type, "list_shape")) { fprintf(stderr, "Invalid json format\n"); status = -1; @@ -319,7 +347,11 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* re status = -1; break; } - reader->ReadString(reader, type); + status = reader->ReadString(reader, type, sizeof(type)); + if (status != 0) { + fprintf(stderr, "error reading device_index array item"); + break; + } if (strcmp(type, "list_int")) { fprintf(stderr, "Invalid json format\n"); status = -1; @@ -348,7 +380,7 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* re status = -1; break; } - reader->ReadString(reader, type); + reader->ReadString(reader, type, sizeof(type)); if (!strcmp(type, "list_int")) { if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); @@ -425,7 +457,7 @@ int TVMGraphRuntime_Load(TVMGraphRuntime* runtime, JSONReader* reader) { 
reader->BeginObject(reader); int bitmask = 0; char key[20]; - while (reader->NextObjectItem(reader, key)) { + while (reader->NextObjectItem(reader, key, sizeof(key))) { if (!strcmp(key, "nodes")) { reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { diff --git a/src/runtime/crt/graph_runtime/load_json.c b/src/runtime/crt/graph_runtime/load_json.c index e4c71fd55095..6f512b5347e7 100644 --- a/src/runtime/crt/graph_runtime/load_json.c +++ b/src/runtime/crt/graph_runtime/load_json.c @@ -158,60 +158,55 @@ char JSONReader_PeekNextNonSpace(JSONReader* reader) { /*! * \brief Parse next JSON string. * \param out_str the output string. + * \param out_str_size Number of bytes available to write starting from out_str. Includes + * terminating \0. * \throw dmlc::Error when next token is not string */ -int JSONReader_ReadString(JSONReader* reader, char* out_str) { +int JSONReader_ReadString(JSONReader* reader, char* out_str, size_t out_str_size) { int status = 0; char ch = reader->NextNonSpace(reader); - char output[128]; - uint32_t output_counter = 0; - memset(output, 0, 128); - while (1) { + size_t output_counter = 0; + while (output_counter < out_str_size) { ch = reader->NextChar(reader); if (ch == '\\') { char sch = reader->NextChar(reader); switch (sch) { case 'r': - snprintf(output + strlen(output), sizeof(output), "\r"); + out_str[output_counter++] = '\r'; break; case 'n': - snprintf(output + strlen(output), sizeof(output), "\n"); + out_str[output_counter++] = '\n'; break; case '\\': - snprintf(output + strlen(output), sizeof(output), "\\"); + out_str[output_counter++] = '\\'; break; case 't': - snprintf(output + strlen(output), sizeof(output), "\t"); + out_str[output_counter++] = '\t'; break; case '\"': - snprintf(output + strlen(output), sizeof(output), "\""); + out_str[output_counter++] = '\"'; break; default: fprintf(stderr, "unknown string escape %c\n", sch); + break; } } else { if (ch == '\"') { break; } - if (strlen(output) >= 127) { - fprintf(stderr, "Error: detected buffer overflow.\n"); - status = -1; - break; - } - strncat(output, &ch, 1); - output_counter++; - if (output_counter >= 127) { - fprintf(stderr, "Error: string size greater than 128.\n"); - status = -1; - break; - } + out_str[output_counter++] = ch; + } + if (output_counter == out_str_size - 1) { + fprintf(stderr, "Error: string size greater than buffer size (%zu).\n", out_str_size); + break; } if (ch == EOF || ch == '\r' || ch == '\n') { fprintf(stderr, "Error at line X, Expect \'\"\' but reach end of line\n"); - status = -1; + break; } } - snprintf(out_str, sizeof(output), "%s", output); + + out_str[output_counter] = 0; return status; } @@ -262,9 +257,10 @@ void JSONReader_BeginObject(JSONReader* reader) { * If this call is successful, user can proceed to call * reader->Read to read in the value. * \param out_key the key to the next object. + * \param out_key_size number of bytes available to write at out_key, including terminating \0. * \return true if the read is successful, false if we are at end of the object. 
*/ -uint8_t JSONReader_NextObjectItem(JSONReader* reader, char* out_key) { +uint8_t JSONReader_NextObjectItem(JSONReader* reader, char* out_key, size_t out_key_size) { uint8_t next = 1; Seq* scope_counter_ = reader->scope_counter_; if (scope_counter_->back(scope_counter_)[0] != 0) { @@ -290,7 +286,11 @@ uint8_t JSONReader_NextObjectItem(JSONReader* reader, char* out_key) { return 0; } else { scope_counter_->back(scope_counter_)[0] += 1; - reader->ReadString(reader, out_key); + int err = reader->ReadString(reader, out_key, out_key_size); + if (err != 0) { + fprintf(stderr, "error reading key"); + return 0; + } int ch = reader->NextNonSpace(reader); if (ch != ':') { fprintf(stderr, "Error at line X, Expect \':\' but get \'%c\'\n", ch); diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_runtime/load_json.h b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_runtime/load_json.h index 39c2576b2282..5cd84c8d1a77 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_runtime/load_json.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_runtime/load_json.h @@ -75,10 +75,10 @@ typedef struct JSONReader { char (*PeekNextNonSpace)(struct JSONReader* reader); int (*ReadUnsignedInteger)(struct JSONReader* reader, unsigned int* out_value); int (*ReadInteger)(struct JSONReader* reader, int64_t* out_value); - int (*ReadString)(struct JSONReader* reader, char* out_value); + int (*ReadString)(struct JSONReader* reader, char* out_str, size_t out_str_size); void (*BeginArray)(struct JSONReader* reader); void (*BeginObject)(struct JSONReader* reader); - uint8_t (*NextObjectItem)(struct JSONReader* reader, char* out_key); + uint8_t (*NextObjectItem)(struct JSONReader* reader, char* out_key, size_t out_key_size); uint8_t (*NextArrayItem)(struct JSONReader* reader); } JSONReader; diff --git a/src/support/str_escape.h b/src/support/str_escape.h index 65eec682086e..4a927340467b 100644 --- a/src/support/str_escape.h +++ b/src/support/str_escape.h @@ -35,9 +35,11 @@ namespace support { * \brief Create a stream with escape. * \param data The data * \param size The size of the string. + * \param use_octal_escape True to use octal escapes instead of hex. If producing C + * strings, use octal escapes to avoid ambiguously-long hex escapes. * \return the Result string. */ -inline std::string StrEscape(const char* data, size_t size) { +inline std::string StrEscape(const char* data, size_t size, bool use_octal_escape = false) { std::ostringstream stream; for (size_t i = 0; i < size; ++i) { unsigned char c = data[i]; @@ -62,8 +64,12 @@ inline std::string StrEscape(const char* data, size_t size) { stream << 'n'; break; default: - const char* hex_digits = "0123456789ABCDEF"; - stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf]; + if (use_octal_escape) { + stream << '0' + ((c >> 6) & 0x03) << '0' + ((c >> 3) & 0x07) << '0' + (c & 0x03); + } else { + const char* hex_digits = "0123456789ABCDEF"; + stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf]; + } } } } diff --git a/src/target/func_registry_generator.cc b/src/target/func_registry_generator.cc new file mode 100644 index 000000000000..402d0f8538e2 --- /dev/null +++ b/src/target/func_registry_generator.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Defines functions that generate FuncRegistry structs for C runtime. + * \file func_registry_generator.cc + */ + +#include "func_registry_generator.h" + +#include + +namespace tvm { +namespace target { + +std::string GenerateFuncRegistryNames(const std::vector& function_names) { + std::stringstream ss; + ss << (unsigned char)(function_names.size()); + for (auto f : function_names) { + ss << f << '\0'; + } + + return ss.str(); +} + +} // namespace target +} // namespace tvm diff --git a/src/target/func_registry_generator.h b/src/target/func_registry_generator.h new file mode 100644 index 000000000000..362fca872eb4 --- /dev/null +++ b/src/target/func_registry_generator.h @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Defines functions that generate FuncRegistry structs for C runtime. + * \file func_registry_generator.h + */ +#ifndef TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ +#define TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ + +#include +#include + +namespace tvm { +namespace target { + +std::string GenerateFuncRegistryNames(const std::vector& function_names); + +} // namespace target +} // namespace tvm + +#endif // TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 93c94cfa4389..758a4f6be7e6 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -250,7 +250,7 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { // have a shorter lifetime than the ctx. 
std::unique_ptr cg(new CodeGenAMDGPU()); - cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false); + cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false, false); for (auto kv : mod->functions) { CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 41fa3c5a3fc1..7eea61e43e8b 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -31,12 +31,15 @@ #include #include +#include "../func_registry_generator.h" + namespace tvm { namespace codegen { void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) { - CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup); + llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, + bool target_c_runtime) { + CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup, target_c_runtime); dbg_info_ = CreateDebugInfo(module_.get()); static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); func_handle_map_.clear(); @@ -51,6 +54,14 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, t_tvm_shape_index_->getPointerTo(), t_int64_}); t_tvm_value_ = llvm::StructType::create({t_float64_}); t_tvm_parallel_group_env_ = llvm::StructType::create({t_int32_->getPointerTo(), t_int32_}); + ftype_tvm_backend_packed_c_func_ = llvm::FunctionType::get( + t_int_, + {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, + t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_void_p_}, + false); + t_tvm_crt_func_registry_ = llvm::StructType::create( + {t_char_->getPointerTo(), ftype_tvm_backend_packed_c_func_->getPointerTo()}); + t_tvm_crt_module_ = llvm::StructType::create({t_tvm_crt_func_registry_->getPointerTo()}); ftype_tvm_parallel_lambda_ = llvm::FunctionType::get( t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo(), t_void_p_}, false); md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_); @@ -75,7 +86,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, ftype_tvm_static_init_callback_->getPointerTo(), t_void_p_, t_int_}, false); // initialize TVM runtime API - if (system_lib) { + if (system_lib && !target_c_runtime) { // We will need this in environment for backward registration. 
f_tvm_register_system_symbol_ = llvm::Function::Create( llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false), @@ -100,6 +111,8 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, "TVMBackendParallelBarrier", module_.get()); } this->InitGlobalContext(dynamic_lookup); + target_c_runtime_ = target_c_runtime; + is_system_lib_ = system_lib; } void CodeGenCPU::AddFunction(const PrimFunc& f) { @@ -109,8 +122,13 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( - std::make_pair(global_symbol.value().operator std::string(), - builder_->CreatePointerCast(function_, t_void_p_))); + std::make_pair(global_symbol.value().operator std::string(), function_)); + } else if (target_c_runtime_) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; + registry_functions_.emplace_back( + std::make_pair(global_symbol.value().operator std::string(), function_)); } AddDebugInformation(function_); } @@ -357,7 +375,7 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { // Module context gv_mod_ctx_ = InitContextPtr(t_void_p_, tvm::runtime::symbol::tvm_module_ctx); // Register back the locations. - if (f_tvm_register_system_symbol_ != nullptr) { + if (f_tvm_register_system_symbol_ != nullptr && !target_c_runtime_) { export_system_symbols_.emplace_back( std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_)); } else { @@ -756,7 +774,46 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { } void CodeGenCPU::AddStartupFunction() { - if (export_system_symbols_.size() != 0) { + if (registry_functions_.size() != 0) { + CHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime"; + std::vector symbols; + std::vector funcs; + for (auto sym : registry_functions_) { + symbols.emplace_back(sym.first); + funcs.emplace_back(llvm::ConstantExpr::getBitCast( + sym.second, ftype_tvm_backend_packed_c_func_->getPointerTo())); + } + llvm::DataLayout layout(module_.get()); + llvm::ArrayType* t_tvm_crt_func_ptrs = + llvm::ArrayType::get(ftype_tvm_backend_packed_c_func_->getPointerTo(), funcs.size()); + llvm::GlobalVariable* func_registry_ptrs = new llvm::GlobalVariable( + *module_, t_tvm_crt_func_ptrs, true, llvm::GlobalValue::InternalLinkage, + llvm::ConstantArray::get(t_tvm_crt_func_ptrs, funcs), "_tvm_func_registry_ptrs"); + uint64_t align = layout.getTypeAllocSize(ftype_tvm_backend_packed_c_func_->getPointerTo()); +#if TVM_LLVM_VERSION >= 100 + func_registry_ptrs->setAlignment(llvm::Align(align)); +#else + func_registry_ptrs->setAlignment(align); +#endif + llvm::GlobalVariable* func_registry = new llvm::GlobalVariable( + *module_, t_tvm_crt_func_registry_, true, llvm::GlobalVariable::InternalLinkage, + llvm::ConstantStruct::get( + t_tvm_crt_func_registry_, + {GetConstString(::tvm::target::GenerateFuncRegistryNames(symbols)), + func_registry_ptrs}), + "_tvm_crt_func_registry"); + llvm::GlobalVariable* module = new llvm::GlobalVariable( + *module_, t_tvm_crt_module_, true, llvm::GlobalValue::InternalLinkage, + llvm::ConstantStruct::get(t_tvm_crt_module_, {func_registry}), "_tvm_crt_module"); + + // Now build TVMSystemLibEntryPoint. 
+ llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_p_, {}, false); + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + "TVMSystemLibEntryPoint", module_.get()); + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + builder_->SetInsertPoint(entry_point_entry); + builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_)); + } else { llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false); function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, "__tvm_module_startup", module_.get()); diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index fdeab4130782..fc46dc53ce15 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -39,7 +39,7 @@ namespace codegen { class CodeGenCPU : public CodeGenLLVM { public: void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup) override; + bool system_lib, bool dynamic_lookup, bool target_c_runtime) override; void AddFunction(const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; @@ -62,6 +62,11 @@ class CodeGenCPU : public CodeGenLLVM { llvm::StructType* t_tvm_array_{nullptr}; llvm::StructType* t_tvm_value_{nullptr}; llvm::StructType* t_tvm_parallel_group_env_{nullptr}; + + llvm::FunctionType* ftype_tvm_backend_packed_c_func_{nullptr}; + llvm::StructType* t_tvm_crt_func_registry_{nullptr}; + llvm::StructType* t_tvm_crt_module_{nullptr}; + llvm::FunctionType* ftype_tvm_parallel_lambda_{nullptr}; llvm::FunctionType* ftype_tvm_func_call_{nullptr}; llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr}; @@ -136,9 +141,13 @@ class CodeGenCPU : public CodeGenLLVM { // global to packed function handle std::unordered_map func_handle_map_; // List of symbols to be exported to TVM system lib. - std::vector > export_system_symbols_; + std::vector> export_system_symbols_; + // List of functions to be registered in the FuncRegistry, if generated. + std::vector> registry_functions_; // internal debug information, to be populated by std::unique_ptr dbg_info_; + bool target_c_runtime_; + bool is_system_lib_; // Get the DWARF type corresponding to the LLVM type |ty|. The current API in practice only // generates |int32|, and |int8*|. 
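The registry blob that GenerateFuncRegistryNames produces, and that the generated _tvm_crt_func_registry points at, is just a count byte followed by NUL-terminated function names. A minimal Python sketch of that byte layout, assuming fewer than 256 registered functions; the helper names here are illustrative only, the real generator is the C++ function shown above:

# Illustration only: mirrors the layout written by GenerateFuncRegistryNames
# (one count byte, then each name terminated with '\0').
def encode_registry_names(function_names):
    assert len(function_names) < 256, "count is stored in a single byte"
    blob = bytes([len(function_names)])
    for name in function_names:
        blob += name.encode("utf-8") + b"\0"
    return blob

def decode_registry_names(blob):
    count = blob[0]
    names = blob[1:].split(b"\0")[:count]
    return [n.decode("utf-8") for n in names]

blob = encode_registry_names(["fused_add", "fused_multiply"])
print(blob)                        # b'\x02fused_add\x00fused_multiply\x00'
print(decode_registry_names(blob))

The function-pointer array emitted alongside this blob uses the same ordering, so the n-th NUL-terminated name corresponds to the n-th entry of _tvm_func_registry_ptrs.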
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 225d22599554..4ffd6b25c237 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -49,7 +49,8 @@ std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { } void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) { + llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, + bool target_c_runtime) { InitializeLLVM(); ctx_ = ctx; builder_.reset(new IRBuilder(*ctx_)); @@ -623,7 +624,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } } -llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { +llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index ce5baba7ea19..351d7d90b4fe 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -73,9 +73,11 @@ class CodeGenLLVM : public ExprFunctor, * \param system_lib Whether to insert system library registration. * \param dynamic_lookup Whether dynamically lookup runtime function * or use the runtime function table passed by caller. + * \param target_c_runtime If true, generate a module to be executed by the C runtime. In practice + * this option influences whether global ctors are used. */ virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup); + bool system_lib, bool dynamic_lookup, bool target_c_runtime); /*! * \brief Compile and add function f to the current module. * \param f The function to be added. @@ -246,7 +248,7 @@ class CodeGenLLVM : public ExprFunctor, void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits); // Get constant string - llvm::Value* GetConstString(const std::string& str); + llvm::Constant* GetConstString(const std::string& str); // do a scalarize call with f llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, const std::vector& args); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 71c8e78030c2..e2690b96a106 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -267,7 +267,7 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { // have a shorter lifetime than the ctx. 
std::unique_ptr cg(new CodeGenNVPTX()); - cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false); + cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false, false); for (auto kv : mod->functions) { CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 9e585d7b4492..de2dadf9bb16 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -189,10 +189,12 @@ class LLVMModuleNode final : public runtime::ModuleNode { return ""; } - void Init(const IRModule& mod, std::string target) { + void Init(const IRModule& mod, std::string target_str) { InitializeLLVM(); - tm_ = GetLLVMTargetMachine(target); - bool system_lib = (target.find("-system-lib") != std::string::npos); + tm_ = GetLLVMTargetMachine(target_str); + auto target = Target::Create(target_str); + bool system_lib = target->GetAttr("system-lib").value_or(Bool(false)); + bool target_c_runtime = (target->GetAttr("runtime").value_or("") == kTvmRuntimeCrt); ctx_ = std::make_shared(); std::unique_ptr cg = CodeGenLLVM::Create(tm_.get()); @@ -211,7 +213,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { CHECK_NE(funcs.size(), 0U); // TODO(tqchen): remove the entry function behavior as it does not // makes sense when we start to use multiple modules. - cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib); + cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); for (const auto& f : funcs) { cg->AddFunction(f); @@ -222,7 +224,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { } module_ = cg->Finish(); - module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, target)); + module_->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*ctx_, target_str)); module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); @@ -235,7 +238,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) << "LLVM module verification failed with the following errors: \n" << verify_errors.str(); - target_ = target; + target_ = target_str; mptr_ = module_.get(); } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 839962a8c733..3f1e3aaef933 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -22,12 +22,15 @@ */ #include "codegen_c_host.h" +#include #include #include #include +#include "../../support/str_escape.h" #include "../build_common.h" +#include "../func_registry_generator.h" namespace tvm { namespace codegen { @@ -43,6 +46,15 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) { CodeGenC::Init(output_ssa); } +void CodeGenCHost::AddFunction(const PrimFunc& f) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; + function_names_.emplace_back(global_symbol.value()); + + CodeGenC::AddFunction(f); +} + void CodeGenCHost::PrintFuncPrefix() { // NOLINT(*) stream << "#ifdef __cplusplus\n" << "extern \"C\"\n" @@ -263,11 +275,34 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, << "? 
(" << a_id << ") : (" << b_id << "))"; } -runtime::Module BuildCHost(IRModule mod) { +void CodeGenCHost::GenerateFuncRegistry() { + decl_stream << "#include \n"; + stream << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n"; + for (auto f : function_names_) { + stream << " " << f << ",\n"; + } + auto registry = target::GenerateFuncRegistryNames(function_names_); + stream << "static const TVMFuncRegistry _tvm_func_registry = {\n" + << " \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\"," + << " _tvm_func_array,\n" + << "};\n"; +} + +void CodeGenCHost::GenerateCrtSystemLib() { + stream << "static const TVMModule _tvm_system_lib = {\n" + << " &system_lib_registry,\n" + << "};\n" + << "const TVMModule* TVMSystemLibEntryPoint(void) {\n" + << " return &system_lib;\n" + << "}\n"; +} + +runtime::Module BuildCHost(IRModule mod, const std::string& target_str) { using tvm::runtime::Registry; bool output_ssa = false; bool emit_asserts = false; CodeGenCHost cg; + auto target = Target::Create(target_str); cg.Init(output_ssa, emit_asserts); for (auto kv : mod->functions) { @@ -276,12 +311,19 @@ runtime::Module BuildCHost(IRModule mod) { cg.AddFunction(f); } + if (target->GetAttr("system-lib").value_or(Bool(false))) { + CHECK_EQ(target->GetAttr("runtime").value_or(""), "c") + << "c target only supports generating C runtime SystemLibs"; + cg.GenerateFuncRegistry(); + cg.GenerateCrtSystemLib(); + } + std::string code = cg.Finish(); return CSourceModuleCreate(code, "c"); } TVM_REGISTER_GLOBAL("target.build.c").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildCHost(args[0]); + *rv = BuildCHost(args[0], args[1]); }); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 94a76faabd78..66ac4ddd99d3 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -26,6 +26,7 @@ #include #include +#include #include "codegen_c.h" #include "tvm/target/codegen.h" @@ -39,6 +40,8 @@ class CodeGenCHost final : public CodeGenC { CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts); + void AddFunction(const PrimFunc& f); + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintFuncPrefix() final; // NOLINT(*) void PrintFinalReturn() final; // NOLINT(*) @@ -53,10 +56,18 @@ class CodeGenCHost final : public CodeGenC { void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*) + /*! \brief Generate C runtime FuncRegistry global constant. */ + void GenerateFuncRegistry(); + + /*! \brief Generate C runtime SystemLib entry point. */ + void GenerateCrtSystemLib(); + private: std::string module_name_; /* \brief tracks declared global variables which live despite GetUniqueName */ std::set declared_globals_; + /* \brief names of the functions declared in this module */ + std::vector function_names_; /*! 
\brief whether to emit asserts in the resulting C code */ bool emit_asserts_; diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 0bef651c30cf..e6f7c5cdec13 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -356,6 +356,7 @@ TVM_REGISTER_TARGET_KIND("llvm") .add_attr_option("device") .add_attr_option("model") .add_attr_option("system-lib") + .add_attr_option("runtime") .add_attr_option("mcpu") .add_attr_option>("mattr") .add_attr_option("mtriple") @@ -369,18 +370,10 @@ TVM_REGISTER_TARGET_KIND("c") .add_attr_option("device") .add_attr_option("model") .add_attr_option("system-lib") + .add_attr_option("runtime") .set_default_keys({"cpu"}) .set_device_type(kDLCPU); -TVM_REGISTER_TARGET_KIND("micro_dev") - .add_attr_option>("keys") - .add_attr_option>("libs") - .add_attr_option("device") - .add_attr_option("model") - .add_attr_option("system-lib") - .set_default_keys({"micro_dev"}) - .set_device_type(kDLMicroDev); - TVM_REGISTER_TARGET_KIND("cuda") .add_attr_option>("keys") .add_attr_option>("libs") diff --git a/tests/python/unittest/test_runtime_micro.py b/tests/python/unittest/test_runtime_micro.py index eb137a9c8cf9..1983def99cf5 100644 --- a/tests/python/unittest/test_runtime_micro.py +++ b/tests/python/unittest/test_runtime_micro.py @@ -28,7 +28,7 @@ # # Use the host emulated micro device. DEV_CONFIG_A = micro.device.host.generate_config() DEV_CONFIG_B = micro.device.host.generate_config() -TARGET = 'micro_dev' +TARGET = 'c --runtime=c' def relay_micro_build(func, dev_config, params=None): """Create a graph runtime module with a micro device context from a Relay function. diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 519d18786e28..d690364f5c5d 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -784,6 +784,16 @@ def dotest(do_vectorize): dotest(True) dotest(False) +def test_llvm_crt_static_lib(): + A = te.placeholder((32, ), dtype='bfloat16') + B = te.placeholder((32, ), dtype='bfloat16') + d = te.compute((32, ), lambda x: A[x] + B[x]) + sch = te.create_schedule(d.op) + module = tvm.build(sch, [A, B, d], target=tvm.target.create('llvm --system-lib --runtime=c')) + print(module.get_source()) + module.save('test.o') + + if __name__ == "__main__": test_multiple_func() test_llvm_large_uintimm() @@ -807,3 +817,4 @@ def dotest(do_vectorize): test_dwarf_debug_information() test_llvm_shuffle() test_llvm_bf16() + test_llvm_crt_static_lib() diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index f7539d6a55fd..d61895c45973 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -36,7 +36,7 @@ make cython3 # Test MISRA-C runtime cd apps/bundle_deploy rm -rf build -# make test_dynamic test_static +make test_dynamic test_static cd ../.. 
# Test extern package From c1eb31566ac7321809f4b9734df97edf378573f6 Mon Sep 17 00:00:00 2001 From: shiwenloong <52487098+shiwenloong@users.noreply.github.com> Date: Fri, 7 Aug 2020 08:55:46 +0800 Subject: [PATCH 3/5] [PYTORCH]Std op without specified dimensions support (#6226) --- python/tvm/relay/frontend/pytorch.py | 11 ++++++++--- tests/python/frontend/pytorch/test_forward.py | 5 +++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3dfdb2f70e7f..bbc684ea8a4c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1253,9 +1253,14 @@ def _impl(inputs, input_types): def _std(): def _impl(inputs, input_types): data = inputs[0] - axis = list(_infer_shape(inputs[1])) - keepdims = bool(inputs[3]) - unbiased = bool(inputs[2]) + if len(inputs) == 2: + axis = None + keepdims = False + unbiased = bool(inputs[1]) + else: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[3]) + unbiased = bool(inputs[2]) if unbiased: msg = "Currently only supports standard-deviation calculated via the biased "\ diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e370cd502b59..3c9dfb13fc4c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1869,12 +1869,17 @@ class Std5(Module): def forward(self, *args): return args[0].std(dim=(2,3), keepdim=False, unbiased=False) + class Std6(Module): + def forward(self, *args): + return args[0].std(unbiased=False) + input_data = torch.rand(input_shape).float() verify_model(Std1().float().eval(), input_data=input_data) verify_model(Std2().float().eval(), input_data=input_data) verify_model(Std3().float().eval(), input_data=input_data) verify_model(Std4().float().eval(), input_data=input_data) verify_model(Std5().float().eval(), input_data=input_data) + verify_model(Std6().float().eval(), input_data=input_data) def test_forward_variance(): From 87f90107846841eba41409d65e8a77c82c033bf4 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Fri, 7 Aug 2020 06:27:20 +0530 Subject: [PATCH 4/5] [ONNX]Mod operator, bug fix (#6160) * Onnx mod, bug fix * Added comment for the mod/floor_mod behaviour difference between numpy & relay --- python/tvm/relay/frontend/onnx.py | 7 +++++- tests/python/frontend/onnx/test_forward.py | 29 ++++++++++------------ 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1568c975dbe0..74626d49a9dd 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -530,10 +530,15 @@ class Mod(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Mod op take 2 inputs, {} given".format(len(inputs)) - if attr['fmod'] == 1: + + # Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod. + # attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment. 
+ # The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod + if attr['fmod'] == 0: op_name = "floor_mod" else: op_name = "mod" + return AttrCvt(op_name)(inputs, {}, params) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 56ea96d7b7e5..14b827c79248 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2374,17 +2374,11 @@ def test_pooling(): auto_pad='SAME_UPPER') -def verify_mod(x_shape, y_shape, fmod, dtype='float32'): - x_np = np.random.uniform(size=x_shape).astype(dtype) - y_np = np.random.uniform(size=y_shape).astype(dtype) +def verify_mod(x_shape, y_shape, fmod, out_shape, dtype='float32'): + x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype) + y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype) y_np = np.where(y_np==0, 1, y_np) #remove 0's to avoid division by zero error - if fmod: - np_out = np.fmod(x_np, y_np) - else: - np_out = np.mod(x_np, y_np) - - out_shape = np_out.shape mod_node = helper.make_node("Mod", inputs=["x", "y"], outputs=["z"], @@ -2401,22 +2395,25 @@ def verify_mod(x_shape, y_shape, fmod, dtype='float32'): onnx_dtype, list(out_shape))]) model = helper.make_model(graph, producer_name='mod_test') + onnx_out = get_onnxruntime_output(model, [x_np, y_np], dtype)[0] + for target, ctx in ctx_list(): tvm_out = get_tvm_output( model, [x_np, y_np], target, ctx, out_shape) - tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) def test_mod(): # Mod - verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=0) - - verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, dtype="int32") + verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32") + verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=0, out_shape=(1, 32, 32, 32), dtype="int32") # fmod - verify_mod(x_shape=[1, 1, 32], y_shape=[1, 32, 32], fmod=1) - - verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, dtype="int32") + verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, out_shape=(1, 32, 32), dtype="int32") + verify_mod(x_shape=[1, 1, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32)) + verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 1, 32, 32], fmod=1, out_shape=(1, 32, 32, 32)) + verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32), dtype="int32") + verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32)) def verify_xor(x_shape, y_shape): From da27e6d9a466263a9a0025aba92086a8bf837edb Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Fri, 7 Aug 2020 04:08:35 +0100 Subject: [PATCH 5/5] Reshape with dynamic shape arg (#6208) Reshape operation updated to take shape from second operand. In case if shape is provided using second operand it can be a tensor now. 
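When the new shape arrives as the second operand, the converter change below flattens whatever it reads back from the shape tensor into a plain Python list before passing it to the reshape op. A small sketch of that flattening step, using the same itertools.chain trick as the converter; the helper name and example values here are made up for illustration:

import numpy as np
from itertools import chain

def flatten_target_shape(target_shape):
    # Nested values such as [[2, 3]] flatten to [2, 3]; an already-flat
    # value such as [6] raises TypeError on the chain(*...) path and
    # falls back to plain chain(...).
    try:
        return list(chain(*target_shape))
    except TypeError:
        return list(chain(target_shape))

print(flatten_target_shape(np.array([[2, 3]], dtype=np.int32)))  # [2, 3]
print(flatten_target_shape(np.array([6], dtype=np.int32)))       # [6]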
--- python/tvm/relay/frontend/tflite.py | 35 +++++++++++++++----- tests/python/frontend/tflite/test_forward.py | 29 ++++++++++++---- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 6e032b1efda8..fe28741e7c59 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -459,26 +459,43 @@ def convert_reshape(self, op): raise ImportError("The tflite package must be installed") input_tensors = self.get_input_tensors(op) - assert input_tensors, "input tensors should not be empty" + assert len(input_tensors) in (1, 2), "input tensors should not be empty" + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "There should be only 1 output tensor" + input_tensor = input_tensors[0] input_tensor_idx = input_tensor.tensor_idx - assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions - op_options = op.BuiltinOptions() - reshape_options = ReshapeOptions() - reshape_options.Init(op_options.Bytes, op_options.Pos) - target_shape = reshape_options.NewShapeAsNumpy() + if len(input_tensors) == 2: + shape_tensor = input_tensors[1] + if self.has_expr(shape_tensor.tensor_idx): + target_shape = self.get_expr(shape_tensor.tensor_idx) + else: + target_shape = self.get_tensor_value(shape_tensor) + # convert to flattened list + from itertools import chain + try: + target_shape = list(chain(*target_shape)) + except TypeError: + target_shape = list(chain(target_shape)) + + else: + assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions + op_options = op.BuiltinOptions() + reshape_options = ReshapeOptions() + reshape_options.Init(op_options.Bytes, op_options.Pos) + target_shape = tuple(reshape_options.NewShapeAsNumpy()) in_expr = self.get_expr(input_tensor_idx) # If the tensors are quantized, ensure that input/output qnn params are same. 
if input_tensor.qnn_params: - output_tensors = self.get_output_tensors(op) - assert len(output_tensors) == 1, "There should be only 1 output tensor" output_tensor = output_tensors[0] assert self.has_same_qnn_params(input_tensor, output_tensor), \ "TFLite reshape requires input and output scale and zero points to be equal" - out = _op.reshape(in_expr, newshape=tuple(target_shape)) + + out = _op.reshape(in_expr, newshape=target_shape) return out def _convert_resize(self, method, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 603eb1169624..30a663176a3f 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -984,20 +984,35 @@ def test_forward_transpose_conv(): # Reshape # ------- -def _test_reshape(data, out_shape): +def _test_reshape(data, out_shape, wrap_shape): """ One iteration of reshape operation with given data and out shape """ with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - out = array_ops.reshape(in_data, out_shape) - compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + out_shape = out_shape if not wrap_shape\ + else np.array(out_shape, dtype=np.int32) + + in_shape = out_shape if not wrap_shape\ + else array_ops.placeholder(shape=out_shape.shape,\ + dtype=out_shape.dtype,\ + name="Newshape") + + out = array_ops.reshape(in_data, in_shape) + + compare_tflite_with_tvm( + [data, out_shape] if wrap_shape else [data],\ + ['Placeholder:0', 'Newshape:0'] if wrap_shape else ['Placeholder:0'],\ + [in_data, in_shape] if wrap_shape else [in_data],\ + [out], + mode='vm') def test_forward_reshape(): - _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3]) - _test_reshape(np.arange(6), [-1, 2]) - _test_reshape(np.arange(6), [3, -1]) - _test_reshape(np.arange(6), [-1]) + for wrap in [True, False]: + _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3], wrap) + _test_reshape(np.arange(6), [-1, 2], wrap) + _test_reshape(np.arange(6), [3, -1], wrap) + _test_reshape(np.arange(6), [-1], wrap) #######################################################################
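In the wrapped-shape cases above, the new shape is not baked into the graph: it is fed at run time as an int32 input tensor ('Newshape:0'), which is why the comparison runs under the VM (mode='vm') where the output shape is only known during execution. A plain-numpy picture of the behaviour those cases exercise, with made-up example values:

import numpy as np

data = np.arange(6, dtype=np.float32)
new_shape = np.array([-1, 2], dtype=np.int32)   # shape arrives as input data
print(np.reshape(data, new_shape))              # result has shape (3, 2)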