From 2a871f35acb0ae31bccf6747073603681c8044ff Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 17 Jan 2019 09:29:12 +0800 Subject: [PATCH] [RELAY][PASS] Support Negative Scale in FoldScaleAxis (#2426) * [RELAY][PASS] Support Negative Scale in FoldScaleAxis * Fix comment --- src/relay/pass/fold_scale_axis.cc | 328 ++++++++++-------- .../python/relay/test_pass_fold_scale_axis.py | 95 ++++- 2 files changed, 281 insertions(+), 142 deletions(-) diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 60df5d90a87c..0cd46ff330e1 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -59,6 +59,36 @@ using runtime::TypedPackedFunc; */ using AxesSet = Array; +class Message; + +/*! + * \brief Message propogated during the prepare phase. + */ +class MessageNode : public RelayNode { + public: + /*! \brief Axes for scaling */ + AxesSet axes; + /*! + * \brief Whether folding requires the scale to be positive constant. This is necessary if some + * operators (e.g. Relu) is present. + */ + bool require_positive; + + static Message make(const AxesSet& axes, bool require_positive); + + static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message"; + TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode); +}; + +RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef); + +Message MessageNode::make(const AxesSet& axes, bool require_positive) { + auto n = make_node(); + n->axes = axes; + n->require_positive = require_positive; + return Message(n); +} + /*! * \brief Merge two axis set together by taking * intersection. @@ -88,14 +118,29 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { return ret; } +/*! + * \brief Merge two messages together by taking intersection. + * + * \param lhs The lhs message. + * \param rhs The rhs message. + * \return The result of intersection. + */ +Message Intersect(const Message& lhs, const Message& rhs) { + if (!lhs.defined()) return lhs; + if (!rhs.defined()) return rhs; + auto axes = Intersect(lhs->axes, rhs->axes); + return MessageNode::make(axes, lhs->require_positive || rhs->require_positive); +} + /*! * \brief Preparation function for pass scale forward. * \param call The call node. - * \param out_scale_axes Possible scaling on axes of the output. - * \return The result scaling on axes of the input. + * \param out_message Message from the output containing possible scaling on axes and whether + * positive scale is required. + * \return The message containing the result scaling on axes of the input. */ using FForwardPrep = runtime::TypedPackedFunc< - Array (const Call& call, const AxesSet& out_scale_axes)>; + Array (const Call& call, const Message& out_message)>; /*! \brief Axis scale tuple. */ class ScaledExprNode : public TempExprNode { @@ -126,16 +171,16 @@ class ScaledExprNode : public TempExprNode { using FForwardRewrite = TypedPackedFunc< Expr(const Call& ref_call, const Array& new_args, - const AxesSet& expeced_out_axes)>; + const Message& message)>; //---------------------------------------------- // Generic Visitors for FScaleAxisForward //---------------------------------------------- class ForwardPrep : private ExprVisitor { public: - std::unordered_map + std::unordered_map Prepare(const Expr& body) { - this->Update(body, NullValue()); + this->Update(body, NullValue()); this->VisitExpr(body); // flist is added in the Post-DFS order // which is a special case of topological order. @@ -152,9 +197,9 @@ class ForwardPrep : private ExprVisitor { // The invoke list std::vector > flist_; // The message on each node. - std::unordered_map message_; + std::unordered_map message_; // Update the message stored at node. - void Update(const Expr& node, const AxesSet& axes) { + void Update(const Expr& node, const Message& message) { // We run intersection of messages: // // %y = multiply(%x, %scale) @@ -167,9 +212,9 @@ class ForwardPrep : private ExprVisitor { // and the forward folding won't be triggered. const Node* key = node.get(); if (message_.count(key)) { - message_[key] = Intersect(message_[key], axes); + message_[key] = Intersect(message_[key], message); } else { - message_[key] = axes; + message_[key] = message; } } // Visitor pattern override. @@ -180,7 +225,7 @@ class ForwardPrep : private ExprVisitor { void VisitExpr_(const FunctionNode* op) { ExprVisitor::VisitExpr_(op); auto flazy = [this, op] { - this->Update(op->body, NullValue()); + this->Update(op->body, NullValue()); }; flist_.push_back(flazy); } @@ -193,23 +238,23 @@ class ForwardPrep : private ExprVisitor { Op::GetAttr("FScaleAxisForwardPrep"); // find the message send to this node. auto it = message_.find(call); - AxesSet out_axes; + Message out_message; if (it != message_.end()) { - out_axes = it->second; + out_message = it->second; } else { - out_axes = NullValue(); + out_message = NullValue(); } // pass the message back to all the children it references. auto f = fprep.get(call->op, nullptr); if (f != nullptr) { - Array in_axes = f(GetRef(call), out_axes); - CHECK_EQ(in_axes.size(), call->args.size()); + Array in_messages = f(GetRef(call), out_message); + CHECK_EQ(in_messages.size(), call->args.size()); for (size_t i = 0; i < call->args.size(); ++i) { - this->Update(call->args[i], in_axes[i]); + this->Update(call->args[i], in_messages[i]); } } else { for (size_t i = 0; i < call->args.size(); ++i) { - this->Update(call->args[i], NullValue()); + this->Update(call->args[i], NullValue()); } } }; @@ -221,7 +266,7 @@ class ForwardPrep : private ExprVisitor { // do not support pass scale through tuple for now. auto flazy = [this, op]() { for (const Expr& field : op->fields) { - this->Update(field, NullValue()); + this->Update(field, NullValue()); } }; flist_.push_back(flazy); @@ -230,13 +275,13 @@ class ForwardPrep : private ExprVisitor { void VisitExpr_(const IfNode* op) { ExprVisitor::VisitExpr_(op); // do pass through condition - // by assigning NullValue + // by assigning NullValue // it means fuse signal cannot pass // through into these subexpressions. auto flazy = [this, op]() { - this->Update(op->cond, NullValue()); - this->Update(op->true_branch, NullValue()); - this->Update(op->false_branch, NullValue()); + this->Update(op->cond, NullValue()); + this->Update(op->true_branch, NullValue()); + this->Update(op->false_branch, NullValue()); }; flist_.push_back(flazy); } @@ -247,13 +292,16 @@ class ForwardPrep : private ExprVisitor { //---------------------------------------------- // Intermediate operators -Array ReluForwardPrep(const Call& call, AxesSet out) { - return {out}; +Array ReluForwardPrep(const Call& call, const Message& out_message) { + if (out_message.defined()) { + return {MessageNode::make(out_message->axes, true)}; + } + return {out_message}; } Expr ReluForwardRewrite(const Call& ref_call, const Array& new_args, - const AxesSet& expected_axes) { + const Message& message) { const auto* input = new_args[0].as(); if (input == nullptr) return Expr(nullptr); // return transformed conv2d @@ -278,23 +326,23 @@ RELAY_REGISTER_OP("nn.leaky_relu") .set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); // AddSub -Array AddSubForwardPrep(const Call& call, AxesSet out_axes) { +Array AddSubForwardPrep(const Call& call, const Message& out_message) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); - - auto none = NullValue(); - if (MatchBroadcastToLeftAxes(tlhs, trhs, out_axes)) { - return {out_axes, none}; - } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_axes)) { - return {none, out_axes}; - } else { - return {none, none}; + auto none = NullValue(); + if (out_message.defined()) { + if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) { + return {out_message, none}; + } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) { + return {none, out_message}; + } } + return {none, none}; } Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, - const AxesSet& expected_out_axes) { + const Message& message) { const auto* slhs = new_args[0].as(); const auto* srhs = new_args[1].as(); if (!slhs && !srhs) return Expr(); @@ -342,9 +390,10 @@ RELAY_REGISTER_OP("subtract") // Multiply produces the scale-axis pair. Expr MultiplyForwardRewrite(const Call& ref_call, const Array& new_args, - const AxesSet& expected_out_axes) { - if (!expected_out_axes.defined()) return Expr(); - if (expected_out_axes.size() == 0) return Expr(); + const Message& message) { + if (!message.defined()) return Expr(); + const auto& expected_out_axes = message->axes; + CHECK(expected_out_axes.defined() && expected_out_axes.size()); // TODO(tvm-team) allow same axes accumulation // not as important because it is less common in nn. const auto* slhs = new_args[0].as(); @@ -356,14 +405,15 @@ Expr MultiplyForwardRewrite(const Call& ref_call, Expr lhs = new_args[0]; Expr rhs = new_args[1]; auto rnode = make_node(); + if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) && - IsAllPositiveConstant(rhs)) { + (!message->require_positive || IsAllPositiveConstant(rhs))) { rnode->value = lhs; rnode->scale = rhs; rnode->axes = expected_out_axes; return Expr(rnode); } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) && - IsAllPositiveConstant(lhs)) { + (!message->require_positive || IsAllPositiveConstant(lhs))) { rnode->value = rhs; rnode->scale = lhs; rnode->axes = expected_out_axes; @@ -378,7 +428,7 @@ RELAY_REGISTER_OP("multiply") // Consumer operators // Conv2D send out requirement of axis folding. -Array Conv2DForwardPrep(const Call& call, AxesSet out) { +Array Conv2DForwardPrep(const Call& call, const Message& out_message) { // TODO(tvm-team) support general data layout // by transforming weight const auto* param = call->attrs.as(); @@ -389,6 +439,7 @@ Array Conv2DForwardPrep(const Call& call, AxesSet out) { int c_small_axis = data_layout.Indexof('c'); CHECK_GE(c_big_axis, 0); + Message none = NullValue(); AxesSet data_axes = NullValue(); // For now, we only support simple pattern (no folded weight/data) // More general layout can be supported under the current framework. @@ -403,13 +454,16 @@ Array Conv2DForwardPrep(const Call& call, AxesSet out) { (param->groups == 1 || is_depthwise_conv2d)) { data_axes = {c_big_axis}; } - return {data_axes, NullValue()}; + if (data_axes.defined()) { + return {MessageNode::make(data_axes, false), none}; + } + return {none, none}; } // Conv2D consumes the scale axis during transformation. Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, - const AxesSet& expected_axes) { + const Message& message) { // if data do not have scale, normal transform path. const auto* sdata = new_args[0].as(); const auto* sweight = new_args[1].as(); @@ -458,11 +512,10 @@ RELAY_REGISTER_OP("nn.conv2d") Expr ForwardFoldScaleAxis(Expr data) { - auto expected_scale_axes = - ForwardPrep().Prepare(data); + auto message = ForwardPrep().Prepare(data); auto fcontext = [&](const Call& call) -> NodeRef{ - auto it = expected_scale_axes.find(call.get()); - if (it != expected_scale_axes.end()) { + auto it = message.find(call.get()); + if (it != message.end()) { return it->second; } else { return NodeRef(nullptr); @@ -484,15 +537,16 @@ class BackwardTransformer; /*! * \brief Preparation function for for pass scale backward. * \param call The call node. - * \param in_scale_axes Allowed input scaling. - * \return The result scaling on axes of the input. + * \param in_messages Messages from the input containing allowed input scaling and whether + * positive scale is required. + * \return Message containing the result scaling on axes of the input. */ using FBackwardPrep = TypedPackedFunc< - AxesSet(const Call& call, const Array& in_scale_axes)>; + Message(const Call& call, const Array& in_messages)>; using FBackwardTransform = TypedPackedFunc< Expr(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer)>; @@ -503,7 +557,7 @@ using FBackwardTransform = TypedPackedFunc< class BackwardPrep : private ExprVisitor { public: // The message on each node. - std::unordered_map + std::unordered_map Prepare(const Expr& body) { ref_counter_ = GetExprRefCount(body); this->VisitExpr(body); @@ -512,7 +566,7 @@ class BackwardPrep : private ExprVisitor { private: // The message on each node. - std::unordered_map message_; + std::unordered_map message_; // reference counter of an internal expr std::unordered_map ref_counter_; // Visit the expression. @@ -527,18 +581,18 @@ class BackwardPrep : private ExprVisitor { // We only allow propagation of scale backward // if the expression is only referred by a single parent. if (rit->second != 1) return; - Array in_axes; + Array in_messages; for (Expr arg : call->args) { auto it = message_.find(arg.get()); if (it != message_.end()) { - in_axes.push_back(it->second); + in_messages.push_back(it->second); } else { - in_axes.push_back(NullValue()); + in_messages.push_back(NullValue()); } } - AxesSet out_axes = f(GetRef(call), in_axes); - if (out_axes.defined()) { - message_[call] = out_axes; + Message out_message = f(GetRef(call), in_messages); + if (out_message.defined()) { + message_[call] = out_message; } } }; @@ -549,7 +603,7 @@ class BackwardTransformerNode : public: // Run forward transform. Expr Fold(Expr expr) { - expected_scale_axes_ = BackwardPrep().Prepare(expr); + message_ = BackwardPrep().Prepare(expr); return this->Mutate(expr); } /*! @@ -560,12 +614,12 @@ class BackwardTransformerNode : * \param scale The scale applied to the axes. * \return The result of transformation. */ - Expr Transform(const Expr& expr, AxesSet axes, Expr scale) { + Expr Transform(const Expr& expr, Message message, Expr scale) { // NOTE: the result of Transform is memoized. if (const CallNode* call_node = expr.as()) { - return Transform(call_node, axes, scale); + return Transform(call_node, message, scale); } else { - CHECK(!axes.defined()) << "outstanding scale"; + CHECK(!message.defined()) << "outstanding scale"; return ExprMutator::VisitExpr(expr); } } @@ -585,14 +639,14 @@ class BackwardTransformerNode : return new_expr; } /*! - * \brief Get the expected axes on expr. + * \brief Get the message propogated to the expr. * \param expr The expresison. - * \return The expected axes. + * \return The message containing the expected axes and whether positive scale is required. */ - AxesSet GetExpectedAxes(const Expr& expr) const { - auto it = expected_scale_axes_.find(expr.get()); - if (it != expected_scale_axes_.end()) return it->second; - return NullValue(); + Message GetMessage(const Expr& expr) const { + auto it = message_.find(expr.get()); + if (it != message_.end()) return it->second; + return NullValue(); } // solver is not serializable. @@ -603,13 +657,13 @@ class BackwardTransformerNode : private: // Valid axes on each node. - std::unordered_map expected_scale_axes_; + std::unordered_map message_; // Override mutation of call. Expr VisitExpr_(const CallNode* call_node) final { - return Transform(call_node, NullValue(), NullValue()); + return Transform(call_node, NullValue(), NullValue()); } // Transform of CallNode. - Expr Transform(const CallNode* call_node, AxesSet axes, Expr scale); + Expr Transform(const CallNode* call_node, Message message, Expr scale); }; class BackwardTransformer : public NodeRef { @@ -625,7 +679,7 @@ class BackwardTransformer : public NodeRef { }; Expr BackwardTransformerNode::Transform( - const CallNode* call_node, AxesSet axes, Expr scale) { + const CallNode* call_node, Message message, Expr scale) { static const auto& ftransform = Op::GetAttr("FScaleAxisBackwardTransform"); auto f = ftransform.get(call_node->op, nullptr); @@ -636,13 +690,13 @@ Expr BackwardTransformerNode::Transform( return it->second; } Expr new_expr = f(GetRef(call_node), - axes, + message, scale, GetRef(this)); memo_[call] = new_expr; return new_expr; } else { - CHECK(!axes.defined()) << "outstanding scale"; + CHECK(!message.defined()) << "outstanding scale"; return NormalCallTransform(call_node); } } @@ -653,19 +707,22 @@ Expr BackwardTransformerNode::Transform( //---------------------------------------------- // Intermediate operators -AxesSet ReluBackwardPrep(const Call& call, const Array& in_axes) { - return in_axes[0]; +Message ReluBackwardPrep(const Call& call, const Array& in_messages) { + if (in_messages[0].defined()) { + return MessageNode::make(in_messages[0]->axes, true); + } + return in_messages[0]; } Expr ReluBackwardTransform(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer) { - if (!axes.defined()) { + if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } Expr input = transformer->Transform( - call->args[0], axes, scale); + call->args[0], message, scale); return CallNode::make(call->op, {input}, call->attrs, call->type_args); } @@ -682,64 +739,63 @@ RELAY_REGISTER_OP("nn.leaky_relu") .set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); // AddSub -AxesSet AddSubBackwardPrep(const Call& call, const Array& in_axes) { +Message AddSubBackwardPrep(const Call& call, const Array& in_messages) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); AttrsEqual equal; - if (in_axes[0].defined() && - MatchBroadcastToLeftAxes(tlhs, trhs, in_axes[0])) { - return in_axes[0]; - } else if (in_axes[1].defined() && - MatchBroadcastToLeftAxes(trhs, tlhs, in_axes[1])) { - return in_axes[1]; - } else if (in_axes[0].defined() && - in_axes[1].defined() && - equal(in_axes[0], in_axes[1]) && + if (in_messages[0].defined() && + MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { + return in_messages[0]; + } else if (in_messages[1].defined() && + MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) { + return in_messages[1]; + } else if (in_messages[0].defined() && + in_messages[1].defined() && + equal(in_messages[0]->axes, in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) { // add of two elements. - return in_axes[0]; + return in_messages[0]; } else { - auto res = NullValue(); - CHECK(!res.defined()); + auto res = NullValue(); return res; } } Expr AddSubBackwardTransform(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); - if (!axes.defined()) { + if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } - AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]); - AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]); + Message lhs_message = transformer->GetMessage(call->args[0]); + Message rhs_message = transformer->GetMessage(call->args[1]); AttrsEqual equal; - if (lhs_axes.defined() && rhs_axes.defined()) { - CHECK(equal(lhs_axes, rhs_axes)); - CHECK(equal(axes, lhs_axes)); - Expr lhs = transformer->Transform(call->args[0], axes, scale); - Expr rhs = transformer->Transform(call->args[1], axes, scale); + if (lhs_message.defined() && rhs_message.defined()) { + CHECK(equal(lhs_message->axes, rhs_message->axes)); + CHECK(equal(message->axes, lhs_message->axes)); + Expr lhs = transformer->Transform(call->args[0], message, scale); + Expr rhs = transformer->Transform(call->args[1], message, scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); - } else if (lhs_axes.defined()) { - CHECK(equal(axes, lhs_axes)); - Expr lhs = transformer->Transform(call->args[0], axes, scale); + } else if (lhs_message.defined()) { + CHECK(equal(message->axes, lhs_message->axes)); + Expr lhs = transformer->Transform(call->args[0], message, scale); Expr rhs = transformer->Transform( - call->args[1], NullValue(), NullValue()); + call->args[1], NullValue(), NullValue()); Expr rhs_scale = ExpandBiasToMatchAxis( - scale, tlhs->shape.size(), axes); + scale, tlhs->shape.size(), message->axes); rhs = Multiply(rhs, rhs_scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); - } else if (rhs_axes.defined()) { - CHECK(equal(axes, rhs_axes)); + } else if (rhs_message.defined()) { + CHECK(equal(message->axes, rhs_message->axes)); Expr lhs = transformer->Transform( - call->args[0], NullValue(), NullValue()); - Expr rhs = transformer->Transform(call->args[1], axes, scale); + call->args[0], NullValue(), NullValue()); + Expr rhs = transformer->Transform(call->args[1], message, scale); Expr lhs_scale = ExpandBiasToMatchAxis( - scale, trhs->shape.size(), axes); + scale, trhs->shape.size(), message->axes); lhs = Multiply(lhs, lhs_scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); } else { @@ -763,29 +819,29 @@ RELAY_REGISTER_OP("subtract") // Producer operators // Multiply produces the scale-axis pair. Expr MultiplyBackwardTransform(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer) { - CHECK(!axes.defined()) << "outstanding scale"; + CHECK(!message.defined()) << "outstanding scale"; const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); - AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]); - AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]); - if (lhs_axes.defined() && lhs_axes.size() != 0) { + Message lhs_message = transformer->GetMessage(call->args[0]); + Message rhs_message = transformer->GetMessage(call->args[1]); + if (lhs_message.defined()) { + CHECK(lhs_message->axes.defined() && lhs_message->axes.size()); // NOTE we won't recursively call mutating on scale part. // since there won't be scale chance within scale part. Expr rhs = call->args[1]; - // Only propagate positive scaling. - if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) && - IsAllPositiveConstant(rhs)) { - return transformer->Transform(call->args[0], lhs_axes, rhs); + if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) && + (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) { + return transformer->Transform(call->args[0], lhs_message, rhs); } - } else if (rhs_axes.defined() && rhs_axes.size() != 0) { - // Only propagate positive scaling. + } else if (rhs_message.defined()) { + CHECK(rhs_message->axes.defined() && rhs_message->axes.size()); Expr lhs = call->args[0]; - if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) && - IsAllPositiveConstant(lhs)) { - return transformer->Transform(call->args[1], rhs_axes, lhs); + if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) && + (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) { + return transformer->Transform(call->args[1], rhs_message, lhs); } } return transformer->NormalCallTransform(call.operator->()); @@ -796,7 +852,7 @@ RELAY_REGISTER_OP("multiply") // Consumer operators // Conv2D send out requirement of axis folding. -AxesSet Conv2DBackwardPrep(const Call& call, const Array& in_axes) { +Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) { const auto* param = call->attrs.as(); CHECK(param != nullptr); Layout kernel_layout(param->kernel_layout); @@ -817,18 +873,18 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array& in_axes) { kernel_layout.Indexof('i') < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { - return {c_big_axis}; + return MessageNode::make({c_big_axis}, false); } else { - return NullValue(); + return NullValue(); } } // Conv2D consumes the scale axis during transformation. Expr Conv2DBackwardTransform(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer) { - if (!axes.defined()) { + if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } const auto* param = call->attrs.as(); @@ -841,8 +897,8 @@ Expr Conv2DBackwardTransform(const Call& call, // TODO(tvm-team) support general data layout CHECK_EQ(kernel_layout.Indexof('o'), -1); CHECK_EQ(kernel_layout.Indexof('i'), -1); - CHECK(axes.size() == 1 && - c_big_axis == axes[0]->value); + CHECK(message->axes.size() == 1 && + c_big_axis == message->axes[0]->value); int big_oc_axis = kernel_layout.Indexof('O'); // Check it must be depthwise or full conv2d. @@ -850,9 +906,9 @@ Expr Conv2DBackwardTransform(const Call& call, CHECK(param->groups == 1 || is_depthwise_conv2d); Expr data = transformer->Transform( - call->args[0], NullValue(), NullValue()); + call->args[0], NullValue(), NullValue()); Expr weight = transformer->Transform( - call->args[1], NullValue(), NullValue()); + call->args[1], NullValue(), NullValue()); // scale on input for deptwise. Expr wscale = ExpandBiasToMatchAxis( scale, kernel_layout.ndim(), {big_oc_axis}); diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index 7d0089cfb3c4..dd9e7522fecf 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -174,7 +174,6 @@ def check(shape, channels, in_scale): assert in_channels == channels weight = relay.var("weight") in_bias = relay.var("in_bias", shape=(in_channels,)) - in_scale = relay.var("in_scale", shape=(in_channels,)) y1 = before(x, weight, in_bias, in_scale, channels) y1 = relay.ir_pass.infer_type(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) @@ -182,10 +181,52 @@ def check(shape, channels, in_scale): in_scale = relay.var("in_scale", shape=(4,)) check((2, 11, 10, 4), 4, in_scale) - in_scale = relay.const(np.random.uniform(size=(4,), low=-1.0, high=0.0)).astype("float32") + in_scale = relay.const(-_get_positive_scale((4,))) check((2, 11, 10, 4), 4, in_scale) +def test_fold_fwd_negative_scale(): + """Testcase of folding negative scale""" + def before(x, conv_weight, in_scale, channels): + args = [x, conv_weight] + x = relay.multiply(x, in_scale) + y = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + return relay.Function(args, y) + + def expected(x, conv_weight, in_scale, channels): + # use a fixed order of args so alpha equal check can pass + args = [x, conv_weight] + squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + y = relay.nn.conv2d(x, + conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + return relay.Function(args, y) + + def check(shape, channels): + x = relay.var("x", shape=shape) + in_channels = shape[1] + in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1))) + weight = relay.var("weight") + y1 = before(x, weight, in_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + type_dict = {x.name_hint:x.checked_type for x in y1.params} + weight = relay.var("weight", type_dict["weight"]) + y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) + y1_expected = expected(x, weight, in_scale, channels) + y1_folded = relay.ir_pass.infer_type(y1_folded) + y1_expected = relay.ir_pass.infer_type(y1_expected) + assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + + check((2, 4, 10, 10), 4) + + def test_fold_bwd_simple(): """Simple testcase.""" def before(x, conv_weight, out_bias, out_scale, channels): @@ -223,7 +264,7 @@ def check(shape, channels): in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) y1 = relay.ir_pass.infer_type(y1) @@ -283,7 +324,7 @@ def check(shape, channels): in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) y1 = relay.ir_pass.infer_type(y1) @@ -356,7 +397,7 @@ def check(shape, channels): in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(np.random.uniform(size=(channels,1, 1)).astype("float32")) + out_scale = relay.const(_get_positive_scale((channels,1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) y1 = relay.ir_pass.infer_type(y1) @@ -411,7 +452,7 @@ def check(shape, channels, fbefore): in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = fbefore(x, weight, out_bias, out_scale, channels) y1 = relay.ir_pass.infer_type(y1) y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) @@ -448,13 +489,55 @@ def check(shape, channels, out_scale): check((4, 4, 10, 10), 4, out_scale) +def test_fold_bwd_negative_scale(): + """Testcase of folding negative scale""" + def before(x, conv_weight, out_scale, channels): + args = [x, conv_weight] + y = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.multiply(y, out_scale) + return relay.Function(args, y) + + def expected(x, conv_weight, out_scale, channels): + # use a fixed order of args so alpha equal check can pass + args = [x, conv_weight] + squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + y = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + return relay.Function(args, y) + + def check(shape, channels): + x = relay.var("x", shape=shape) + weight = relay.var("weight") + out_scale = relay.const(-_get_positive_scale((channels, 1, 1))) + y1 = before(x, weight, out_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + type_dict = {x.name_hint:x.checked_type for x in y1.params} + weight = relay.var("weight", type_dict["weight"]) + y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) + y1_expected = expected(x, weight, out_scale, channels) + y1_folded = relay.ir_pass.infer_type(y1_folded) + y1_expected = relay.ir_pass.infer_type(y1_expected) + assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + + check((2, 4, 10, 10), 8) + + if __name__ == "__main__": test_fold_fwd_simple() test_fold_fwd_dual_path() test_fold_fwd_fail() test_fold_fwd_relu_fail() + test_fold_fwd_negative_scale() test_fold_bwd_simple() test_fold_bwd_dual_path() test_fold_bwd_dual_consumer() test_fold_bwd_fail() test_fold_bwd_relu_fail() + test_fold_bwd_negative_scale()