diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index a555ecd68a65..1cd1ed63db97 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -98,6 +98,8 @@ class IndexedGraph { array_view inputs; /*! \brief control flow dependencies to the node */ array_view control_deps; + /*! \brief weak reference to node */ + std::weak_ptr weak_ref; }; /*! \return number of nodes in the graph */ inline size_t num_nodes() const { diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 5f1fbcc0df56..8570c03d88f0 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -11,7 +11,8 @@ OPT_PASS_LEVEL = { "SimplifyInference": 0, "PrecomputePrune": 2, - "OpFusion": 1 + "OpFusion": 1, + "FoldScaleAxis": 3 } # List of optimization pass and level when switch on @@ -144,6 +145,10 @@ def optimize(graph, shape, dtype="float32"): if cfg.pass_enabled("SimplifyInference"): graph = graph_attr.set_shape_inputs(graph, shape) graph = graph.apply(["InferShape", "SimplifyInference"]) + + if cfg.pass_enabled("FoldScaleAxis"): + graph = graph_attr.set_shape_inputs(graph, shape) + graph = graph.apply(["InferShape", "FoldScaleAxis"]) return graph @@ -291,5 +296,6 @@ def precompute_prune(graph, params): out_names = pre_graph.json_attr("output_names") if not pre_graph.symbol.list_output_names(): return graph, params - out_arrs = _run_graph(pre_graph, params) + with tvm.build_config(auto_unroll_max_step=0): + out_arrs = _run_graph(pre_graph, params) return graph, dict(zip(out_names, out_arrs)) diff --git a/nnvm/python/nnvm/testing/init.py b/nnvm/python/nnvm/testing/init.py index 3ce7e40ef87f..36ddcc955f7c 100644 --- a/nnvm/python/nnvm/testing/init.py +++ b/nnvm/python/nnvm/testing/init.py @@ -81,7 +81,6 @@ def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3): self.factor_type = factor_type self.magnitude = float(magnitude) - def _init_weight(self, name, arr): shape = arr.shape hw_scale = 1. diff --git a/nnvm/python/nnvm/testing/mobilenet.py b/nnvm/python/nnvm/testing/mobilenet.py index bebd7ddc6802..1d59df64ab4a 100644 --- a/nnvm/python/nnvm/testing/mobilenet.py +++ b/nnvm/python/nnvm/testing/mobilenet.py @@ -30,7 +30,8 @@ def separable_conv_block(data, name, depthwise_channels, # depthwise convolution + bn + relu conv1 = sym.conv2d(data=data, channels=depthwise_channels, groups=depthwise_channels, kernel_size=kernel_size, strides=strides, - padding=padding, use_bias=False, layout="NCHW", name=name + "_depthwise_conv1") + padding=padding, use_bias=False, layout="NCHW", + name=name + "_depthwise_conv1") bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1") act1 = sym.relu(data=bn1, name=name + "_relu1") # pointwise convolution + bn + relu diff --git a/nnvm/python/nnvm/testing/utils.py b/nnvm/python/nnvm/testing/utils.py index fcc008e61fc3..d6c03fc1b745 100644 --- a/nnvm/python/nnvm/testing/utils.py +++ b/nnvm/python/nnvm/testing/utils.py @@ -46,7 +46,7 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224), input_shapes, _ = graph_util.infer_shape(g, data=data_shape) shape_dict = dict(zip(g.index.input_names, input_shapes)) np.random.seed(seed) - initializer = initializer if initializer else Xavier(magnitude=3) + initializer = initializer if initializer else Xavier() for k, v in shape_dict.items(): if k == "data": continue diff --git a/nnvm/src/README.md b/nnvm/src/README.md index da3584a73cb1..adae68105650 100644 --- a/nnvm/src/README.md +++ b/nnvm/src/README.md @@ -7,8 +7,7 @@ The following components are operator invariant. - core: NNVM core data structure - pass: NNVM pass -The following components are generic graph compiler for NNVM-TOP +The following components are generic NNVM compiler and defines tensor operator set -- top: NNVM-TOP core operator defs -- tvm: NNVM-TOP to TVM compiler toolchain -- runtime: NNVM-TOP runtime +- top: NNVM core tensor operators +- compiler: NNVM compiler toolchain diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index d31612f5a826..e345497e2512 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -58,7 +58,7 @@ class CompileEngine { return it->second->graph_func; } GraphFunc f = DoLower(key->graph, key->inputs, key->target, - schedule_op_key, schedule_op_attr); + schedule_op_key, schedule_op_attr); std::shared_ptr n = std::make_shared(); n->graph_func = f; n->use_count = 1; diff --git a/nnvm/src/compiler/fold_scale_axis.cc b/nnvm/src/compiler/fold_scale_axis.cc new file mode 100644 index 000000000000..eb0ea4b87292 --- /dev/null +++ b/nnvm/src/compiler/fold_scale_axis.cc @@ -0,0 +1,271 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file fold_scale_axis.cc + * \author Fold scaling parameter of axis into weight of conv/dense +*/ +#include +#include +#include +#include +#include +#include +#include "./pattern_util.h" +#include "./graph_transform.h" + +namespace nnvm { +namespace compiler { + +enum FoldScaleKind { + // No folding is applied + kNone, + // The folding decision is pending + kPending, + // The original operator that contains the scale. + kProvider, + // Pass through the scale to parent/child to the first axis. + kPassTroughFirst, + // The final conumer of axis scale using multiply + // Likely be a conv or dense operator. + kMulConsumer, + // The final conumer of axis scale using division + kDivConsumer +}; + +// Input fold information +struct FoldScaleInput { + uint32_t index; + int axis; +}; + +// The entry of folding chains on which +// we should perform folding on +struct FoldChainEntry { + // Entry kind + FoldScaleKind kind{kNone}; + // The output axis to be folded + int axis{0}; + // Source node in the fold chain + int source{0}; + // Following field only used by provider. + // The input index + int fold_input_index{1}; + // The scale entry + NodeEntry scale_entry; +}; + +// Try to pass axis scaling to backward, +// Given that we we know the status of current fold axis. +using FScaleAxisBackward = std::function< + FoldScaleKind(const NodeAttrs& attrs, + int axis, + const std::vector& in_shape, + const std::vector& out_shape, + std::vector >* in_axis)>; + +// Detect if there is a scaling axis happening +bool DetectScaleAxis(const IndexedGraph& idx, + uint32_t nid, + const ShapeVector& shape_vec, + const std::vector& ref_count, + bool is_forward, + std::vector* chain) { + const IndexedGraph::Node& inode = idx[nid]; + static const Op* bcast_mul = Op::Get("broadcast_mul"); + static const Op* expand_dims = Op::Get("expand_dims"); + if (inode.source->op() != bcast_mul) return false; + const TShape& oshape = shape_vec[idx.entry_id(nid, 0)]; + CHECK_NE(oshape.ndim(), 0); + if (oshape.ndim() <= 1) return false; + for (int i = 0; i < 2; ++i) { + const IndexedGraph::NodeEntry& a = inode.inputs[i]; + const IndexedGraph::NodeEntry& b = inode.inputs[1 - i]; + std::pair axis = + MatchBroadcast1DAxis(oshape, shape_vec[idx.entry_id(a)]); + if (axis.first != -1 && + shape_vec[idx.entry_id(b)] == oshape) { + if (ref_count[a.node_id] != 1) return false; + if (is_forward && ref_count[nid] != 1) return false; + if (!is_forward && ref_count[b.node_id] != 1) return false; + const IndexedGraph::Node& anode = idx[a.node_id]; + // mark the current entry. + FoldChainEntry& e = (*chain)[nid]; + if (anode.source->is_variable()) { + e.fold_input_index = 1 - i; + e.scale_entry = inode.source->inputs[1 - i]; + } else if (anode.source->op() == expand_dims && + shape_vec[idx.entry_id(anode.source->inputs[0])].ndim() == 1) { + e.fold_input_index = 1 - i; + e.scale_entry = anode.source->inputs[0]; + } else { + return false; + } + e.axis = axis.first; + e.kind = kPending; + e.source = nid; + if (!is_forward) { + // pass message to another input + FoldChainEntry& enext = (*chain)[b.node_id]; + enext.axis = e.axis; + enext.kind = kPending; + enext.source = nid; + } + return true; + } + } + return false; +} + +Graph FoldScaleAxis(Graph src) { + // Operator pattern + static auto& fbackward = + nnvm::Op::GetAttr("FScaleAxisBackward"); + const IndexedGraph& idx = src.indexed_graph(); + const ShapeVector& shape_vec = src.GetAttr("shape"); + std::vector ref_count = GetNodeRefCounts(idx); + std::vector bwd_chain(idx.num_nodes()); + // shape hint for the inference. + std::vector in_shape, out_shape; + // perform backward folding. + for (uint32_t i = idx.num_nodes(); i != 0; --i) { + uint32_t nid = i - 1; + const auto& inode = idx[nid]; + if (inode.source->is_variable()) continue; + if (DetectScaleAxis(idx, nid, shape_vec, + ref_count, false, &bwd_chain)) continue; + if (bwd_chain[nid].kind != kPending) continue; + if (ref_count[nid] != 1 || !fbackward.count(inode.source->op())) { + bwd_chain[nid].kind = kNone; continue; + } + // get input shape and output shape. + in_shape.clear(); out_shape.clear(); + for (const IndexedGraph::NodeEntry& e : inode.inputs) { + in_shape.push_back(shape_vec[idx.entry_id(e)]); + } + for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) { + out_shape.push_back(shape_vec[idx.entry_id(nid, i)]); + } + std::vector > in_axis; + FoldScaleKind kind = + fbackward[inode.source->op()]( + inode.source->attrs, bwd_chain[nid].axis, + in_shape, out_shape, &in_axis); + bwd_chain[nid].kind = kind; + if (kind == kNone) continue; + CHECK_GE(in_axis.size(), 1U); + CHECK(kind == kPassTroughFirst || kMulConsumer); + // propagate back. + bool can_prop = true; + for (size_t i = 0; i < in_axis.size(); ++i) { + const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[0].first]; + if (ref_count[e.node_id] != 1 || + idx[e.node_id].source->num_outputs() != 1) { + can_prop = false; break; + } + } + if (!can_prop) continue; + for (size_t i = 0; i < in_axis.size(); ++i) { + const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[i].first]; + if (kind == kPassTroughFirst && i == 0) { + bwd_chain[e.node_id].kind = kPending; + } else { + bwd_chain[nid].kind = kNone; + bwd_chain[e.node_id].kind = kMulConsumer; + } + bwd_chain[e.node_id].axis = in_axis[i].second; + bwd_chain[e.node_id].source = bwd_chain[nid].source; + } + if (kind == kMulConsumer) { + bwd_chain[bwd_chain[nid].source].kind = kProvider; + } + } + auto transform = [&](uint32_t nid, const NodePtr& n, std::vector* ret) { + const FoldChainEntry& e = bwd_chain[nid]; + if (e.kind == kMulConsumer && bwd_chain[e.source].kind == kProvider) { + const FoldChainEntry& se = bwd_chain[e.source]; + CHECK_EQ(n->num_outputs(), 1); + NodeEntry scale = ExpandBiasToMatchAxis( + se.scale_entry, + shape_vec[idx.entry_id(nid, 0)].ndim(), + shape_vec[idx.entry_id(se.scale_entry)].ndim(), + e.axis); + *ret = {MakeNode("broadcast_mul", n->attrs.name + "_sc", + {NodeEntry{n, 0, 0}, scale})}; + return true; + } else if (e.kind == kProvider) { + *ret = {n->inputs[e.fold_input_index]}; + return true; + } else { + return false; + } + }; + return GraphTransform(src, transform); +} + +NNVM_REGISTER_PASS(FoldScaleAxis) +.set_body(FoldScaleAxis); + +// property registration. +FoldScaleKind ReluScaleAxisBackward( + const NodeAttrs& attrs, + int axis, + const std::vector& in_shape, + const std::vector& out_shape, + std::vector >* in_axis) { + in_axis->emplace_back(0, axis); + return kPassTroughFirst; +} + +NNVM_REGISTER_OP(relu) +.set_attr("FScaleAxisBackward", ReluScaleAxisBackward); + +NNVM_REGISTER_OP(leaky_relu) +.set_attr("FScaleAxisBackward", ReluScaleAxisBackward); + +FoldScaleKind BroadcastAddSubScaleAxisBackward( + const NodeAttrs& attrs, + int axis, + const std::vector& in_shape, + const std::vector& out_shape, + std::vector >* in_axis) { + for (int i = 0; i < 2; ++i) { + std::pair m = MatchBroadcast1DAxis(out_shape[0], in_shape[i]); + if (m.second != -1 && in_shape[1 - i] == out_shape[0]) { + in_axis->emplace_back(i, axis); + in_axis->emplace_back(1 - i, m.second); + return kPassTroughFirst; + } + } + return kNone; +} + +NNVM_REGISTER_OP(broadcast_add) +.set_attr("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward); + +NNVM_REGISTER_OP(broadcast_sub) +.set_attr("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward); + +FoldScaleKind Conv2DScaleAxisBackward( + const NodeAttrs& attrs, + int axis, + const std::vector& in_shape, + const std::vector& out_shape, + std::vector >* in_axis) { + using top::Conv2DParam; + const Conv2DParam& param = nnvm::get(attrs.parsed); + // only optimize for nchw for now + if (param.layout == top::kNCHW) { + in_axis->emplace_back(1, 0); + if (param.use_bias) { + in_axis->emplace_back(2, 0); + } + return kMulConsumer; + } else { + return kNone; + } +} + +NNVM_REGISTER_OP(conv2d) +.set_attr("FScaleAxisBackward", Conv2DScaleAxisBackward); + +} // namespace compiler +} // namespace nnvm diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc index b8c6a1adc02c..0f11fe6d46ac 100644 --- a/nnvm/src/compiler/graph_fuse.cc +++ b/nnvm/src/compiler/graph_fuse.cc @@ -16,6 +16,7 @@ #include #include "./compile_engine.h" #include "./graph_runtime.h" +#include "./pattern_util.h" namespace nnvm { namespace compiler { @@ -56,17 +57,10 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { // Reference counter of each op node // For now, always store result when an op is referred more than once. - std::vector ref_count(idx.num_nodes(), 0); - for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { - const auto& inode = idx[nid]; - if (inode.source->is_variable()) continue; - for (const auto& e : inode.inputs) { - ++ref_count[e.node_id]; - } - } + std::vector ref_count = GetNodeRefCounts(idx); for (const auto& e : idx.outputs()) { // this line will realize all the outputs - ref_count[e.node_id] += 2; + ref_count[e.node_id] += 1; } // Pattern for the subgraph std::vector pattern_vec(idx.num_nodes(), kOpaque); diff --git a/nnvm/src/compiler/graph_transform.h b/nnvm/src/compiler/graph_transform.h index 2099809115aa..dd80accbee3f 100644 --- a/nnvm/src/compiler/graph_transform.h +++ b/nnvm/src/compiler/graph_transform.h @@ -20,7 +20,7 @@ namespace compiler { * * \param graph The original graph * - * \param ftransform Function of (int nid, const Node* node, std::vector* out) -> bool + * \param ftransform Function of (int nid, const NodePtr& node, std::vector* out) -> bool * * If empty vector is returned, it means original entries should be kept. * @@ -36,7 +36,6 @@ Graph GraphTransform(Graph graph, FTransform ftransform) { // setup inputs and placeholder. for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; - if (inode.source->is_variable()) continue; bool need_copy = false; for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (updated[idx.entry_id(e)]) { @@ -57,7 +56,7 @@ Graph GraphTransform(Graph graph, FTransform ftransform) { if (!need_copy) { std::vector ret; - if (ftransform(nid, inode.source, &ret)) { + if (ftransform(nid, inode.weak_ref.lock(), &ret)) { CHECK_EQ(ret.size(), static_cast(inode.source->num_outputs())); for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) { updated[idx.entry_id(nid, i)] = true; @@ -93,7 +92,7 @@ Graph GraphTransform(Graph graph, FTransform ftransform) { node->control_deps.push_back(selected_ptr); } std::vector ret; - if (ftransform(nid, node.get(), &ret)) { + if (ftransform(nid, node, &ret)) { CHECK_EQ(ret.size(), static_cast(inode.source->num_outputs())); for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) { updated[idx.entry_id(nid, i)] = true; diff --git a/nnvm/src/compiler/pattern_util.h b/nnvm/src/compiler/pattern_util.h new file mode 100644 index 000000000000..10322a7dce85 --- /dev/null +++ b/nnvm/src/compiler/pattern_util.h @@ -0,0 +1,99 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file pattern_util.h + * \brief Utilities for doing various pattern matching in graph. +*/ +#ifndef NNVM_COMPILER_PATTERN_UTIL_H_ +#define NNVM_COMPILER_PATTERN_UTIL_H_ + +#include +#include +#include +#include + +namespace nnvm { +namespace compiler { + +/*! + * \brief find axis in oshape, such that: + * bias_shape = [1,1, ... oshape[axis], 1,1,] + * + * This is used to detect bias or scaling factor on channel dimension. + * \param oshape The output shape + * \param bias_shape The shape of bias or scaling factor. + * \return Pair of matched axis in o shape and bias_shape if found. + */ +inline std::pair MatchBroadcast1DAxis( + const TShape& oshape, const TShape& bias_shape) { + dim_t axis_dim = bias_shape.ndim(); + for (dim_t i = bias_shape.ndim(); i != 0; --i, --axis_dim) { + if (bias_shape[i - 1] != 1) break; + } + // everything is 1 + if (axis_dim == 0) { + return {oshape.ndim() - bias_shape.ndim(), 0}; + } + axis_dim = axis_dim - 1; + // The bias shape is not 1D + for (dim_t i = 0; i < axis_dim; ++i) { + if (bias_shape[i] != 1) return {-1, -1}; + } + int axis = static_cast( + oshape.ndim() - bias_shape.ndim() + axis_dim); + if (oshape[axis] != bias_shape[axis_dim]) return {-1, -1}; + return {axis, axis_dim}; +} + +/*! + * \brief Expand bias dimension to match needed axis. + * + * \param bias The bias NodeEntry + * \param out_dim output dimension. + * \param bias_dim The current bias dimension. + * \param axis The axis we want to match on. + */ +inline NodeEntry +ExpandBiasToMatchAxis(NodeEntry bias, + int out_dim, + int bias_dim, + int axis) { + if (bias_dim != 1) { + bias = MakeNode("squeeze", bias.node->attrs.name + "_sqz", {bias}); + } + int num_pad_axis = out_dim - axis - 1; + if (num_pad_axis > 0) { + std::unordered_map kwargs{ + {"axis", "1"}, + {"num_newaxis", std::to_string(num_pad_axis)}}; + return MakeNode("expand_dims", bias.node->attrs.name + "_expand", + {bias}, kwargs); + + } else { + return bias; + } +} + +/*! + * \brief Get the reference count of each node. + * \param idx The IndexedGraph + * \return ref_count vector of length number nodes. + */ +inline std::vector +GetNodeRefCounts(const IndexedGraph& idx) { + std::vector ref_count(idx.num_nodes(), 0); + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + const auto& inode = idx[nid]; + if (inode.source->is_variable()) continue; + for (const auto& e : inode.inputs) { + ++ref_count[e.node_id]; + } + } + for (const auto& e : idx.outputs()) { + // this line will realize all the outputs + ref_count[e.node_id] += 1; + } + return ref_count; +} +} // namespace compiler +} // namespace nnvm +#endif // NNVM_COMPILER_PATTERN_UTIL_H_ diff --git a/nnvm/src/compiler/simplify_inference.cc b/nnvm/src/compiler/simplify_inference.cc index 7dd9ade0ac96..141950b0545d 100644 --- a/nnvm/src/compiler/simplify_inference.cc +++ b/nnvm/src/compiler/simplify_inference.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2017 by Contributors - * \file simplify_batch_norm.cc + * \file simplify_inference.cc * \author Ziheng Jiang */ #include @@ -10,6 +10,7 @@ #include #include #include "./graph_transform.h" +#include "./pattern_util.h" namespace nnvm { namespace compiler { @@ -58,15 +59,9 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, shift = MakeNode( "elemwise_add", bn_name + "_add_beta", {shift, beta}); } - // use expand dims to pad lower dims to 1 - int num_pad_axis = static_cast(dshape.ndim() - param.axis) - 1; - if (num_pad_axis != 0) { - std::unordered_map kwargs{ - {"axis", std::to_string(param.axis)}, - {"num_newaxis", std::to_string(num_pad_axis)}}; - scale = MakeNode("expand_dims", bn_name + "_sc_expand", {scale}, kwargs); - shift = MakeNode("expand_dims", bn_name + "_sh_expand", {shift}, kwargs); - } + int axis = param.axis; + scale = ExpandBiasToMatchAxis(scale, dshape.ndim(), 1, axis); + shift = ExpandBiasToMatchAxis(shift, dshape.ndim(), 1, axis); NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data", {data, scale}); out = MakeNode("broadcast_add", bn_name + "_out", @@ -80,7 +75,7 @@ Graph SimplifyInference(nnvm::Graph src) { // Get attributes from the graph const IndexedGraph& idx = src.indexed_graph(); const ShapeVector& shape_vec = src.GetAttr("shape"); - auto transform = [&](uint32_t nid, const Node* n, std::vector* ret) { + auto transform = [&](uint32_t nid, const NodePtr& n, std::vector* ret) { if (n->is_variable()) return false; static const Op* bn_op = Op::Get("batch_norm"); static const Op* dropout_op = Op::Get("dropout"); diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index d1b6efb66dbb..62c7085c1210 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -28,6 +28,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { // nodes_ IndexedGraph::Node new_node; new_node.source = n.get(); + new_node.weak_ref = n; nodes_.emplace_back(std::move(new_node)); // arg_nodes_ if (n->is_variable()) { diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 12ae06731145..c3ca3681cba9 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -460,21 +460,21 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, std::vector oshape; if (param.axis.ndim() == 0) { for (dim_t i = 0; i < shp.ndim(); ++i) { - if(shp[i] != 1) { + if (shp[i] != 1) { oshape.emplace_back(shp[i]); } } } else { std::unordered_set axis_checker; for (size_t i = 0; i < param.axis.ndim(); ++i) { - if(param.axis[i] < 0) { + if (param.axis[i] < 0) { int real_axis = param.axis[i] + static_cast(shp.ndim()); CHECK(real_axis < static_cast(shp.ndim()) && real_axis >= 0); axis_checker.insert(real_axis); } } for (size_t i = 0; i < shp.ndim(); ++i) { - if(axis_checker.find(i) == axis_checker.end()) { + if (axis_checker.find(i) == axis_checker.end()) { oshape.emplace_back(shp[i]); } else { CHECK_EQ(shp[i], 1) << "The squeezed axis must have shape 1!" @@ -483,7 +483,7 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, } } } - if(oshape.size() == 0) { + if (oshape.size() == 0) { // Handles the case where all axes are squeezed. oshape.push_back(1); } diff --git a/nnvm/tests/python/compiler/test_fold_axis.py b/nnvm/tests/python/compiler/test_fold_axis.py new file mode 100644 index 000000000000..f306b3a222aa --- /dev/null +++ b/nnvm/tests/python/compiler/test_fold_axis.py @@ -0,0 +1,49 @@ +"""Unittest cases for fold_axis""" +import nnvm +from nnvm import symbol as sym +from nnvm.compiler import graph_util, graph_attr + +def test_fold_axis_conv(): + def before(x, conv_weight, conv_bias, scale, channels): + y = sym.conv2d(x, conv_weight, conv_bias, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1), + name="conv") + y = sym.relu(y) + y = y * sym.expand_dims(scale, axis=1, num_newaxis=2) + return y + + def expected(x, conv_weight, conv_bias, scale, channels): + conv_weight = conv_weight * sym.expand_dims(scale, axis=1, num_newaxis=3) + conv_bias = conv_bias * scale + y = sym.conv2d(x, + conv_weight, + conv_bias, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1), + name="conv") + y = sym.relu(y) + return y + + # Before simplify + def check(shape, channels): + x = sym.Variable("x") + 1 + weight = sym.Variable("weight") + bias = sym.Variable("bias") + scale = sym.Variable("scale") + y1 = before(x, weight, bias, scale, channels) + y2 = expected(x, weight, bias, scale, channels) + ishape = {"x": shape, "scale": (channels,)} + g1 = nnvm.graph.create(y1) + g2 = nnvm.graph.create(y2) + graph_attr.set_shape_inputs(g1, ishape) + g1 = g1.apply("InferShape").apply("FoldScaleAxis") + # assert graph equals as expected + graph_util.check_graph_equal(g1, g2) + + check((2, 4, 10, 10), 2) + +if __name__ == "__main__": + test_fold_axis_conv() diff --git a/nnvm/tests/python/compiler/test_simplify_inference.py b/nnvm/tests/python/compiler/test_simplify_inference.py index 5da46608081a..e2826765995e 100644 --- a/nnvm/tests/python/compiler/test_simplify_inference.py +++ b/nnvm/tests/python/compiler/test_simplify_inference.py @@ -14,8 +14,8 @@ def simple_bn(x, gamma, beta, moving_mean, moving_var, # for 2D num_newaxis=len(shape) - axis - 1 if num_newaxis: - scale = sym.expand_dims(scale, axis=axis, num_newaxis=num_newaxis) - shift = sym.expand_dims(shift, axis=axis, num_newaxis=num_newaxis) + scale = sym.expand_dims(scale, axis=1, num_newaxis=num_newaxis) + shift = sym.expand_dims(shift, axis=1, num_newaxis=num_newaxis) return x * scale + shift @@ -39,8 +39,6 @@ def check(dim, axis, nstep): g2 = nnvm.graph.create(y2) graph_attr.set_shape_inputs(g, ishape) g1 = g.apply("InferShape").apply("SimplifyInference") - # Some prints for debug - # print(g1.ir()) # assert graph equals as expected graph_util.check_graph_equal(g1, g2) diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index d19aa5664fdf..ca7a0156a8b5 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -28,7 +28,8 @@ def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'): new_sym, params = frontend.from_mxnet(symbol, args, auxs) dshape = x.shape shape_dict = {'data': dshape} - graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params) + with nnvm.compiler.build_config(opt_level=3): + graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params) m = graph_runtime.create(graph, lib, ctx) # set inputs m.set_input("data", tvm.nd.array(x.astype(dtype)))