diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index dfad1013701f9..cb87d358e966b 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode { /*! \brief Attributes used in squeeze operators */ struct SqueezeAttrs : public tvm::AttrsNode { - Array axes; + // use axis to make the name numpy compatible. + Array axis; TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") { - TVM_ATTR_FIELD(axes) - .describe("The axes to squeeze in the input tensor." - "If `axes = []`, all axis of dimension 1 get squeezed;" + TVM_ATTR_FIELD(axis) + .describe("The axis to squeeze in the input tensor." + "If `axis = None`, all axis of dimension 1 get squeezed;" "Else, the dimension in axes get squeezed." - "It is an error if an axes does not has dimension 1.") - .set_default(Array({})); + "It is an error if an axis does not has dimension 1.") + .set_default(NullValue >()); } }; // struct SqueezeAttrs diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 50e5dfa8d89be..2e3bbadb7841b 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -40,6 +40,18 @@ class ExprNode : public RelayNode { "field for this node"; return this->checked_type_; } + /*! + * \brief Check if the inferred(checked) type of the Expr + * is backed by a TTypeNode and return it. + * + * \note This function will thrown an error if the node type + * of this Expr is not TTypeNode. + * + * \return The corresponding TTypeNode pointer. + * \tparam The specific TypeNode we look for. + */ + template + inline const TTypeNode* type_as() const; static constexpr const char* _type_key = "relay.Expr"; TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); @@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode { RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); +// implementataions +template +inline const TTypeNode* ExprNode::type_as() const { + static_assert(std::is_base_of::value, + "TType must be a special case of type"); + CHECK(checked_type_.defined()) + << "Type inference for this Expr has not completed"; + const TTypeNode* node = checked_type_.as(); + CHECK(node != nullptr) + << "Expected type to be " << TTypeNode::_type_key + << ", but get " << checked_type_->type_key(); + return node; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index c0256cf3a1c37..bf4025f79224c 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -150,7 +150,14 @@ class ExprVisitor class ExprMutator : public ::tvm::relay::ExprFunctor { public: - Expr Mutate(const Expr& expr); + /*! + * \brief Mutate is alias for VisitExpr + * \return expr. + */ + Expr Mutate(const Expr& expr) { + return this->VisitExpr(expr); + } + Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const ConstantNode* op) override; Expr VisitExpr_(const GlobalVarNode* op) override; @@ -161,7 +168,8 @@ class ExprMutator Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override; - /*! \brief Used to visit the types inside of expressions. + /*! + * \brief Used to visit the types inside of expressions. * * Can be overloaded to transform the types in arbitrary * ways, one way would be to define a sub-class of type @@ -169,7 +177,7 @@ class ExprMutator */ virtual Type VisitType(const Type& t); - private: + protected: /*! \brief Internal map used for memoization. */ std::unordered_map memo_; }; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 9f28fbebccfcc..ad447ad13cee4 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -74,6 +74,17 @@ class OpNode : public relay::ExprNode { v->Visit("support_level", &support_level); } + /*! + * \brief Check that if current op is a "primtive operator". + * That is the arguments are all type variables, and there is a single + * type relation applied to the input and output types. + */ + bool IsPrimitiveOp() const { + if (is_primitive_ != -1) return is_primitive_ != 0; + is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0; + return is_primitive_ != 0; + } + static constexpr const char* _type_key = "relay.Op"; TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); @@ -81,9 +92,24 @@ class OpNode : public relay::ExprNode { // friend class friend class GenericOpMap; friend class OpRegistry; + friend bool IsPrimitiveOp(const Expr&); // Program internal unique index of operator. // Used to help index the program. uint32_t index_{0}; + // whether this is a primitive op. -1 means unknown. + mutable int is_primitive_{-1}; + // Internal function to compute if it is primitive op + bool IsPrimitiveOp_() const { + const auto& fn_ty = this->op_type; + if (fn_ty->type_constraints.size() != 1) return false; + const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); + if (rel == nullptr) return false; + // validate if the type parameter matches up + for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { + if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; + } + return true; + } }; /*! @@ -497,22 +523,7 @@ inline ValueType OpMap::get(const Op& op, */ inline bool IsPrimitiveOp(const Expr& expr) { const auto* op = expr.as(); - - if (!op) { - return false; - } - - const auto& fn_ty = op->op_type; - if (fn_ty->type_constraints.size() != 1) return false; - - const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); - if (rel == nullptr) return false; - // validate if the type parameter matches up - for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { - if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; - } - - return true; + return op != nullptr && op->IsPrimitiveOp(); } } // namespace relay diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index f930751c41a7d..6adfaacdc86d3 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -10,6 +10,7 @@ from .expr import Expr from .ty import Type + def infer_type(expr, env=None): """Infer the type of expr under the context of env. @@ -30,6 +31,23 @@ def infer_type(expr, env=None): return _ir_pass.infer_type(expr, env) +def forward_fold_scale_axis(expr): + """Fold the scaling of axis into weights of conv2d/dense. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression, we expect that expr's types + should be fully inferred by infer_type. + + Returns + ------- + folded_expr : tvm.relay.Expr + The folded expression after transformation. + """ + return _ir_pass.forward_fold_scale_axis(expr) + + def well_formed(expr): """Check that each Var is only bound once (well formed). @@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs): """ return bool(_make._alpha_equal(lhs, rhs)) + def graph_equal(lhs, rhs): """Compare two Relay expr for data-flow equivalence. The difference between this and alpha-equality is that @@ -170,6 +189,7 @@ def graph_equal(lhs, rhs): """ return bool(_make._graph_equal(lhs, rhs)) + def structural_hash(value): """Hash a Relay expression structurally. diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 9d14463a530c7..909b175f08ca6 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -49,27 +49,25 @@ def transpose(data, axes=None): return _make.transpose(data, list(axes)) -def squeeze(data, axes=None): +def squeeze(data, axis=None): """Squeeze axes in the array. Parameters ---------- - data : relay.Expr + data : tvm.relay.Expr The input data to the operator. - axes : None or List[int] - Axes to remove. - If axes = [] or = None, remove all axis of dimensions 1. - Otherwise, remove all axis in axes. - If any axis in axes has dimension that does not equal 1, it is an error. + axis : None or List[int] + The set of axes to remove. + If axis = None, remove all axis of dimensions 1. + If any specified axis has dimension that does not equal 1, it is an error. Returns ------- - result : relay.Expr + result : tvm.relay.Expr The squeezed result. """ - axes = axes or [] - return _make.squeeze(data, list(axes)) + return _make.squeeze(data, axis) def reshape(data, newshape): diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 7aab9bb3223b5..8409581b53bf4 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -296,13 +296,23 @@ class AlphaEqualHandler: if (const CallNode* rhs = other.as()) { if (!ExprEqual(lhs->op, rhs->op)) return false; if (lhs->args.size() != rhs->args.size()) return false; - if (lhs->type_args.size() != rhs->type_args.size()) return false; - + // skip type_args check for primitive ops. + bool is_primitive = IsPrimitiveOp(lhs->op); + if (!is_primitive) { + if (lhs->type_args.size() != rhs->type_args.size()) { + return false; + } + } for (size_t i = 0; i < lhs->args.size(); ++i) { - if (!ExprEqual(lhs->args[i], rhs->args[i])) return false; + if (!ExprEqual(lhs->args[i], rhs->args[i])) { + return false; + } } - for (size_t i = 0; i < lhs->type_args.size(); ++i) { - if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false; + + if (!is_primitive) { + for (size_t i = 0; i < lhs->type_args.size(); ++i) { + if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false; + } } return AttrEqual(lhs->attrs, rhs->attrs); } else { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 557daa98e8998..b7a752d43a5c3 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -12,12 +12,12 @@ namespace tvm { namespace relay { -Expr ExprMutator::Mutate(const Expr& expr) { +Expr ExprMutator::VisitExpr(const Expr& expr) { auto it = this->memo_.find(expr); if (it != this->memo_.end()) { return it->second; } else { - Expr new_expr = ExprMutator::VisitExpr(expr); + Expr new_expr = ExprFunctor::VisitExpr(expr); memo_[expr] = new_expr; return new_expr; } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 5faa0805426a8..635f04668f331 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -761,9 +761,9 @@ Examples:: TVM_REGISTER_NODE_TYPE(SqueezeAttrs); Expr MakeSqueeze(Expr data, - Array axes) { + Array axis) { auto attrs = make_node(); - attrs->axes = std::move(axes); + attrs->axis = std::move(axis); static const Op& op = Op::Get("squeeze"); return CallNode::make(op, {data}, Attrs(attrs), {}); } @@ -785,8 +785,8 @@ bool SqueezeRel(const Array& types, const auto* param = attrs.as(); CHECK(param != nullptr); std::vector result_shape; - // if axes is empty, squeeze all axes of dimension 1 - if (param->axes.size() == 0) { + // if axes is None, squeeze all axes of dimension 1 + if (!param->axis.defined()) { for (const auto& e : data->shape) { const int64_t* axis_ptr = as_const_int(e); CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; @@ -800,10 +800,8 @@ bool SqueezeRel(const Array& types, for (const auto& e : data->shape) { original_shape.push_back(std::pair(e, true)); } - for (const auto& e : param->axes) { - const int64_t* axis_ptr = as_const_int(e); - CHECK(axis_ptr != nullptr); - original_shape.at(*axis_ptr).second = false; + for (const auto& e : param->axis) { + original_shape.at(e->value).second = false; } for (const auto p : original_shape) { if (p.second) { diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc new file mode 100644 index 0000000000000..b1c767704372e --- /dev/null +++ b/src/relay/pass/fold_scale_axis.cc @@ -0,0 +1,554 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file fold_scale_axis.cc + * + * \brief Fold axis scaling into weights of + * conv/dense operators. + */ +#include +#include +#include +#include "pattern_util.h" +#include "../op/nn/layout.h" + +namespace tvm { +namespace relay { +/*! + * \brief namespace of fold scale axis + * + * Use namespace to reduce potential naming conflict. + */ +namespace fold_scale_axis { + +using runtime::TypedPackedFunc; + + +// FoldScaleAxisFoward algorithm: +// +// The general idea is that we transform Expr to tuple of +// (value, axes, scale), where the final result satiesfies: +// +// result = value +// for i, k in enumerate(axes): +// k-ith dimension of result *= i-th dimension of scale +// +// Then we can propagate this signal along and fold the scale if necessary. +// However, it is possible that certain scale may never be consumed +// if there is no dense/conv2d that follows multiplication. +// +// In order to make sure all the scale we sent out can be consumed eventually, +// we run a backward "preparation phase", which propagates the demand +// of the potential axes scaling back to its input. +// +// The folding process is done in two steps: +// - Prepare phase: backward propagation of demand. +// - Transform phase: forward transformation, + +/*! + * \brief sorted array axis, can also be nullptr. + * + * nullptr means no scaling request can be done. + */ +using AxesSet = Array; + +/*! + * \brief Merge two axis set together by taking + * intersection. + * + * \note The axes in a AxesSet should be sorted. + * + * \param lhs The left axis. + * \param rhs The right axis. + * \return The result of the inersection. + */ +AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { + if (!lhs.defined()) return lhs; + if (!rhs.defined()) return rhs; + // This code relies on axes in a AxesSet to be sorted. + AxesSet ret; + size_t i = 0, j = 0; + while (i < lhs.size() && j < rhs.size()) { + if (lhs[i]->value < rhs[j]->value) { + ++i; + } else if (lhs[i]->value > rhs[j]->value) { + ++j; + } else { + ret.push_back(lhs[i]); + ++i; ++j; + } + } + return ret; +} + +/*! + * \param Get function from op_map. + * \param op_map The OpMap. + * \param op The operator being called. + * \tparam ValueType the content value type. + * \return The result value map. + */ +template +ValueType GetFunc(const OpMap& op_map, + const Expr& op) { + if (const OpNode* opnode = op.as()) { + return op_map.get(GetRef(opnode), ValueType()); + } else { + return ValueType(); + } +} + +/*! + * \brief Preparation function for for pass scale forward. + * \param call The call node. + * \param out_scale_axes Possible scaling on axes of the output. + * \return The result scaling on axes of the input. + */ +using FForwardPrep = runtime::TypedPackedFunc< + Array (const Call& call, const AxesSet& out_scale_axes)>; + +/*! \brief Axis scale tuple. */ +class STupleNode : public Node { + public: + /*! \brief The value */ + Expr value; + /*! \brief The axes to scale, can be nullptr(means no-scaling) */ + AxesSet axes = NullValue(); + /*! \brief The scaling factor */ + Expr scale = NullValue(); + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("value", &value); + v->Visit("axes", &axes); + v->Visit("scale", &scale); + } + + static constexpr const char* _type_key = "relay.fold_scale_axis.STupleNode"; + TVM_DECLARE_NODE_TYPE_INFO(STupleNode, Node); +}; + +RELAY_DEFINE_NODE_REF(STuple, STupleNode, NodeRef); + +/*! + * \brief The transform function, transform an old call to + * a new one given the new args. + * \param ref_call Reference call node that represent the op and the types. + * \param expected_out_axes The scale axes allowed in the output. + * \param sargs The input arguments. + */ +using FForwardTransform = TypedPackedFunc< + STuple(const Call& ref_call, + const AxesSet& expected_out_axes, + const Array& sargs)>; + +//---------------------------------------------- +// Generic Visitors for FScaleAxisForward +//---------------------------------------------- +class FScaleAxisForwardPrep : private ExprVisitor { + public: + std::unordered_map + Prepare(const Expr& body) { + this->Update(body, NullValue()); + this->VisitExpr(body); + // flist is added in the Post-DFS order + // which is a special case of topological order. + // We reversely traverse the list to invoke the lazy functions. + // This act like a backprop of valid scale axis messages + for (auto it = flist_.rbegin(); it != flist_.rend(); ++it) { + (*it)(); + } + // return the created message; + return std::move(message_); + } + + private: + // The invoke list + std::vector > flist_; + // The message on each node. + std::unordered_map message_; + // Update the message stored at node. + void Update(const Expr& node, const AxesSet& axes) { + // We run intersection of messages: + // + // %y = multiply(%x, %scale) + // %z1 = conv2d(%y, %w) + // %z2 = exp(%y) + // + // Consider the above code example, + // because %z2 will propagate null to %y, + // the AxesSet on %y is also null, + // and the forward folding won't be triggered. + const Node* key = node.get(); + if (message_.count(key)) { + message_[key] = Intersect(message_[key], axes); + } else { + message_[key] = axes; + } + } + // Visitor pattern override. + void VisitExpr_(const LetNode* call) { + LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; + } + + void VisitExpr_(const FunctionNode* op) { + ExprVisitor::VisitExpr_(op); + auto flazy = [this, op] { + this->Update(op->body, NullValue()); + }; + flist_.push_back(flazy); + } + + void VisitExpr_(const CallNode* call) { + ExprVisitor::VisitExpr_(call); + // function to be lazily invoked + auto flazy = [this, call]() { + static const auto& fprep = + Op::GetAttr("FScaleAxisForwardPrep"); + // find the message send to this node. + auto it = message_.find(call); + AxesSet out_axes; + if (it != message_.end()) { + out_axes = it->second; + } else { + out_axes = NullValue(); + } + // pass the message back to all the children it references. + auto f = GetFunc(fprep, call->op); + if (f != nullptr) { + Array in_axes = f(GetRef(call), out_axes); + CHECK_EQ(in_axes.size(), call->args.size()); + for (size_t i = 0; i < call->args.size(); ++i) { + this->Update(call->args[i], in_axes[i]); + } + } else { + for (size_t i = 0; i < call->args.size(); ++i) { + this->Update(call->args[i], NullValue()); + } + } + }; + flist_.push_back(flazy); + } + + void VisitExpr_(const TupleNode* op) { + ExprVisitor::VisitExpr_(op); + // do not support pass scale through tuple for now. + auto flazy = [this, op]() { + for (const Expr& field : op->fields) { + this->Update(field, NullValue()); + } + }; + flist_.push_back(flazy); + } + + void VisitExpr_(const IfNode* op) { + ExprVisitor::VisitExpr_(op); + // do pass through condition + // by assigning NullValue + // it means fuse signal cannot pass + // through into these subexpressions. + auto flazy = [this, op]() { + this->Update(op->cond, NullValue()); + this->Update(op->true_branch, NullValue()); + this->Update(op->false_branch, NullValue()); + }; + flist_.push_back(flazy); + } +}; + +class FScaleAxisForwardTransform : private ExprMutator { + public: + // Transform expression. + Expr Transform(Expr expr) { + expected_scale_axes_ = + FScaleAxisForwardPrep().Prepare(expr); + return this->Mutate(expr); + } + + private: + // Valid axes on each node. + std::unordered_map expected_scale_axes_; + std::unordered_map scale_memo_; + // If user simply call mutate, + // then only Expr is returned and we cannot + // accept outstanding scales. + Expr VisitExpr(const Expr& expr) final { + Expr res = ExprMutator::VisitExpr(expr); + CHECK(!scale_memo_.count(expr.get())) + << "Outstanding scale"; + return res; + } + + STuple GetSTuple(const Expr& expr) { + Expr res = ExprMutator::VisitExpr(expr); + auto it = scale_memo_.find(expr.get()); + if (it != scale_memo_.end()) { + CHECK(it->second->value.same_as(res)); + return it->second; + } else { + auto node = make_node(); + node->value = res; + return STuple(node); + } + } + + Expr VisitExpr_(const CallNode* call_node) final { + static const auto& ftransform = + Op::GetAttr("FScaleAxisForwardTransform"); + auto new_op = this->Mutate(call_node->op); + bool has_scale = false; + bool unchanged = call_node->op.same_as(new_op); + + Array call_sargs; + Array call_args; + for (auto arg : call_node->args) { + STuple new_sarg = this->GetSTuple(arg); + unchanged &= new_sarg->value.same_as(arg); + if (new_sarg->axes.defined()) has_scale = true; + call_sargs.push_back(new_sarg); + call_args.push_back(new_sarg->value); + } + + // get expected scale axes. + AxesSet expected_out_axes; + auto axis_it = expected_scale_axes_.find(call_node); + if (axis_it != expected_scale_axes_.end()) { + expected_out_axes = axis_it->second; + } + // propagation function + auto f = GetFunc(ftransform, call_node->op); + if (f != nullptr) { + STuple sret = f(GetRef(call_node), expected_out_axes, call_sargs); + if (sret.defined()) { + if (sret->axes.defined()) { + scale_memo_[call_node] = sret; + } + return sret->value; + } + } + // normal path + CHECK(!has_scale) << "Outstanding scale, on op=" << call_node->op; + if (unchanged) { + return GetRef(call_node); + } else { + return CallNode::make( + new_op, call_args, call_node->attrs, call_node->type_args); + } + } +}; + +//---------------------------------------------- +// Per operator defs for FScaleAxisForward +//---------------------------------------------- + +// Intermediate operators +Array ReluForwardPrep(const Call& call, AxesSet out) { + return {out}; +} + +STuple ReluForwardTransform(const Call& ref_call, + const AxesSet& expected_axes, + const Array& sargs) { + if (!sargs[0]->axes.defined()) return STuple(); + // return transformed conv2d + auto rnode = make_node(); + rnode->value = CallNode::make( + ref_call->op, {sargs[0]->value}, ref_call->attrs, {}); + rnode->scale = sargs[0]->scale; + rnode->axes = sargs[0]->axes; + return STuple(rnode); +} + +RELAY_REGISTER_OP("nn.relu") +.set_attr("FScaleAxisForwardPrep", ReluForwardPrep); + +RELAY_REGISTER_OP("nn.relu") +.set_attr("FScaleAxisForwardTransform", ReluForwardTransform); + +RELAY_REGISTER_OP("nn.leaky_relu") +.set_attr("FScaleAxisForwardPrep", ReluForwardPrep); + +RELAY_REGISTER_OP("nn.leaky_relu") +.set_attr("FScaleAxisForwardTransform", ReluForwardTransform); + +// AddSub +Array AddSubForwardPrep(const Call& call, AxesSet out_axes) { + const auto* tlhs = call->args[0]->type_as(); + const auto* trhs = call->args[1]->type_as(); + + auto none = NullValue(); + if (MatchBroadcastToLeftAxes(tlhs, trhs, out_axes)) { + return {out_axes, none}; + } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_axes)) { + return {none, out_axes}; + } else { + return {none, none}; + } +} + +STuple AddSubForwardTransform(const Call& ref_call, + const AxesSet& expected_out_axes, + const Array& sargs) { + if (!sargs[0]->axes.defined() && !sargs[1]->axes.defined()) { + return STuple(); + } + const auto* tlhs = ref_call->args[0]->type_as(); + const auto* trhs = ref_call->args[1]->type_as(); + + auto rnode = make_node(); + if (sargs[0]->axes.defined()) { + CHECK(!sargs[1]->axes.defined()); + CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, sargs[0]->axes)); + Expr scale = ExpandBiasToMatchAxis( + sargs[0]->scale, tlhs->shape.size(), sargs[0]->axes); + Expr rhs = Divide(sargs[1]->value, scale); + rnode->value = CallNode::make(ref_call->op, {sargs[0]->value, rhs}, + ref_call->attrs, ref_call->type_args); + rnode->scale = sargs[0]->scale; + rnode->axes = sargs[0]->axes; + } else { + CHECK(sargs[1]->axes.defined()); + CHECK(sargs[0]->axes.defined()); + CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, sargs[1]->axes)); + Expr scale = ExpandBiasToMatchAxis( + sargs[1]->scale, trhs->shape.size(), sargs[1]->axes); + Expr lhs = Divide(sargs[0]->value, scale); + rnode->value = CallNode::make(ref_call->op, {lhs, sargs[1]->value}, + ref_call->attrs, ref_call->type_args); + rnode->scale = sargs[1]->scale; + rnode->axes = sargs[1]->axes; + } + return STuple(rnode); +} + +RELAY_REGISTER_OP("add") +.set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); + +RELAY_REGISTER_OP("add") +.set_attr("FScaleAxisForwardTransform", AddSubForwardTransform); + +RELAY_REGISTER_OP("subtract") +.set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); + +RELAY_REGISTER_OP("subtract") +.set_attr("FScaleAxisForwardTransform", AddSubForwardTransform); + +// Producer operators +// Multiply produces the scale-axis pair. +STuple MultiplyForwardTransform(const Call& ref_call, + const AxesSet& expected_out_axes, + const Array& sargs) { + if (!expected_out_axes.defined()) return STuple(); + // TODO(tvm-team) allow same axes accumulation + // not as important because it is less common in nn. + CHECK(!sargs[0]->axes.defined()); + CHECK(!sargs[1]->axes.defined()); + const auto* tlhs = ref_call->args[0]->type_as(); + const auto* trhs = ref_call->args[1]->type_as(); + + Expr lhs = sargs[0]->value; + Expr rhs = sargs[1]->value; + auto rnode = make_node(); + if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) { + rnode->value = lhs; + rnode->scale = rhs; + rnode->axes = expected_out_axes; + } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs)) { + rnode->value = rhs; + rnode->scale = lhs; + rnode->axes = expected_out_axes; + } + return STuple(rnode); +} + +RELAY_REGISTER_OP("multiply") +.set_attr("FScaleAxisForwardTransform", MultiplyForwardTransform); + +// Consumer operators +// Conv2D send out requirement of axis folding. +Array Conv2DForwardPrep(const Call& call, AxesSet out) { + // TODO(tvm-team) support general data layout + // by transforming weight + const auto* param = call->attrs.as(); + CHECK(param != nullptr); + Layout data_layout(param->data_layout); + Layout weight_layout(param->weight_layout); + int c_big_axis = data_layout.indexof('C'); + int c_small_axis = data_layout.indexof('c'); + const auto* tdata = call->args[0]->type_as(); + CHECK(tdata) << "require checked type"; + + CHECK_GE(c_big_axis, 0); + AxesSet data_axes = NullValue(); + // For now, we only support simple pattern (no folded weight/data) + // More general layout can be supported under the current framework. + // By using a unified layout transformation. + // We only need to change the Prep and Mutate function. + // + // only handle depthwise or full conv2d. + // TODO(tvm-team) handle grouped conv by reshape + bcast + bool is_depthwise_conv2d = + is_const_int(tdata->shape[c_big_axis], param->groups); + if (weight_layout.indexof('i') < 0 && + c_small_axis < 0 && + (param->groups == 1 || is_depthwise_conv2d)) { + data_axes = {c_big_axis}; + } + return {data_axes, NullValue()}; +} + +// Conv2D consumes the scale axis during transformation. +STuple Conv2DForwardTransform(const Call& ref_call, + const AxesSet& expected_axes, + const Array& sargs) { + // if data do not have scale, normal transform path. + STuple sdata = sargs[0]; + if (!sdata->scale.defined()) return STuple(); + CHECK(sdata->axes.defined()); + const auto* param = ref_call->attrs.as(); + CHECK(param != nullptr); + Layout data_layout(param->data_layout); + Layout weight_layout(param->weight_layout); + int c_big_axis = data_layout.indexof('C'); + CHECK_GE(c_big_axis, 0); + // For now, we only support simple pattern (no folded weight/data) + // TODO(tvm-team) support general data layout + CHECK_EQ(weight_layout.indexof('i'), -1); + CHECK(sdata->axes.size() == 1 && + c_big_axis == sdata->axes[0]->value); + int big_ic_axis = weight_layout.indexof('I'); + + const auto* tdata = ref_call->args[0]->type_as(); + // Check it must be depthwise or full conv2d. + bool is_depthwise_conv2d = + is_const_int(tdata->shape[c_big_axis], param->groups); + CHECK(param->groups == 1 || is_depthwise_conv2d); + + // match the ic_axis + Expr scale = ExpandBiasToMatchAxis( + sdata->scale, weight_layout.ndim(), {big_ic_axis}); + Expr weight = Multiply(sargs[1]->value, scale); + // return transformed conv2d + auto rnode = make_node(); + rnode->value = CallNode::make( + ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); + return STuple(rnode); +} + +RELAY_REGISTER_OP("nn.conv2d") +.set_attr("FScaleAxisForwardPrep", Conv2DForwardPrep); + +RELAY_REGISTER_OP("nn.conv2d") +.set_attr("FScaleAxisForwardTransform", Conv2DForwardTransform); + + +Expr ForwardFoldScaleAxis(Expr data) { + return FScaleAxisForwardTransform().Transform(data); +} + +// Expose the FoldScaleAxisFoward +TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis") +.set_body_typed(ForwardFoldScaleAxis); + +} // namespace fold_scale_axis +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h new file mode 100644 index 0000000000000..a395e74cdf0b8 --- /dev/null +++ b/src/relay/pass/pattern_util.h @@ -0,0 +1,123 @@ +/*! + * Copyright (c) 2018 by Contributors. + * + * \file tvm/relay/pass/pattern_util.h + * \brief Header of internal operator functions + * These can be used for writing passes. + */ +#ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_ +#define TVM_RELAY_PASS_PATTERN_UTIL_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Try to match lhs and rhs via broadcasting rule, such that: + * + * rhs matches the dimension of lhs specified by lhs_axes + * rhs's value equals 1 on rest of dimensions. + * + * \param tlhs The type of left operand (data) + * \param trhs The type right operand (bias) + * \param lhs_axes The axes on lhs to match. + * \param rhs_value A squeezed version of rhs which only contains matched dimension. + * \return Whether match is successful. + */ +inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, + const TensorTypeNode* trhs, + const Array& lhs_axes, + Expr* rhs_value = nullptr) { + if (tlhs->shape.size() < trhs->shape.size()) return false; + AttrsEqual equal; + size_t base = tlhs->shape.size() - trhs->shape.size(); + size_t j = 0; + + NodePtr squeeze_attrs; + if (rhs_value != nullptr) { + squeeze_attrs = make_node(); + } + + for (size_t i = 0; i < tlhs->shape.size(); ++i) { + if (j < lhs_axes.size() && i == static_cast(lhs_axes[j]->value)) { + if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) { + return false; + } + ++j; + } else if (i >= base) { + if (!is_const_int(trhs->shape[i - base], 1)) { + return false; + } + if (rhs_value != nullptr) { + squeeze_attrs->axis.push_back(static_cast(i - base)); + } + } + } + if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) { + static const Op& squeeze_op = Op::Get("squeeze"); + *rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {}); + } + return true; +} + +/*! + * \brief Expand 1D Tensor to match axis. + * + * The result bias can be used to add or multiply to + * the target Tensor on the specified axis via broadcasting rule. + * + * \param bias The bias. + * \param target_ndim target dimension. + * \param axes The axis on the output we want to match on. + */ +inline Expr ExpandBiasToMatchAxis(Expr bias, + int target_ndim, + const Array& axes) { + static const Op& expand_dims = Op::Get("expand_dims"); + for (size_t i = axes.size(); i != 0; --i) { + if (i == axes.size()) { + int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1; + if (num_pad_axis > 0) { + auto attrs = make_node(); + attrs->axis = i; + attrs->num_newaxis = static_cast(num_pad_axis); + bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); + } + } else { + int64_t diff = axes[i]->value - axes[i - 1]->value; + CHECK_GE(diff, 0L); + if (diff > 0) { + auto attrs = make_node(); + attrs->axis = i; + attrs->num_newaxis = static_cast(diff); + bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); + } + } + } + return bias; +} + +inline Expr Multiply(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("multiply"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + +inline Expr Divide(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("divide"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + + +inline Expr ReshapeLike(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("reshape_like"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + + + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 7c8eeef92c5d5..c1f6cdc639740 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -406,28 +406,57 @@ class TypeInferencer::Resolver : public ExprMutator { CHECK(checked_type.as() == nullptr) << "Cannot resolve type of " << GetRef(op) << " at " << op->span; + Expr new_e = ExprMutator::VisitExpr_(op); - if (!checked_type.same_as(new_e->checked_type_)) { + // new_call and new_var's code is only going to be valid for VarNode/CallNode. + // Compiler optimization will likely fold these away for other nodes. + CallNode* new_call =( + std::is_base_of::value ? + static_cast(new_e.node_.get()) : nullptr); + VarNode* new_var =( + std::is_base_of::value ? + static_cast(new_e.node_.get()) : nullptr); + + // check if we need update the new_e + bool need_update_type = !checked_type.same_as(new_e->checked_type_); + bool need_update_call = ( + std::is_base_of::value && + it->second.type_args.defined() && + !it->second.type_args.same_as(new_call->type_args)); + bool need_update_var = ( + std::is_base_of::value && + update_missing_type_annotation_ && + !new_var->type_annotation.defined()); + + if (!need_update_type && !need_update_var && !need_update_call) return new_e; + + if (!new_e.node_.unique()) { // Copy on write optimization // If new_e is an old expression, // we make a copy mutating an existing reference. - if (!new_e.node_.unique()) { - new_e = Expr(make_node(*new_e.as())); - } - new_e->checked_type_ = checked_type; + new_e = Expr(make_node(*new_e.as())); + new_call = ( + std::is_base_of::value ? + static_cast(new_e.node_.get()) : nullptr); + new_var = ( + std::is_base_of::value ? + static_cast(new_e.node_.get()) : nullptr); } - if (it->second.type_args.defined()) { - Call call = Downcast(new_e); - const CallNode* const_call_ref = call.operator->(); - CallNode* call_ref = const_cast(const_call_ref); - call_ref->type_args = it->second.type_args; + // attach the information. + if (need_update_type) { + new_e->checked_type_ = checked_type; + } - for (size_t i = 0; i < call->type_args.size(); i++) { - call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i])); + if (need_update_call) { + new_call->type_args = it->second.type_args; + for (size_t i = 0; i < new_call->type_args.size(); i++) { + new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i])); } } - + if (need_update_var) { + new_var->type_annotation = checked_type; + } return new_e; } @@ -438,6 +467,9 @@ class TypeInferencer::Resolver : public ExprMutator { private: const std::unordered_map& tmap_; TypeSolver* solver_; + // whether attach the checked type as type_annotation + // if original type anntation is missing. + bool update_missing_type_annotation_{true}; }; diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 2ee6f758f1004..427ac562fbc7c 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -55,8 +55,8 @@ def test_transpose_infer_type(): def test_squeeze_infer_type(): n, t, d = 1, 4, 1 x = relay.var("x", relay.TensorType((n, t, d), "float32")) - y = relay.squeeze(x, axes=(2,)) - assert "axes=" in y.astext() + y = relay.squeeze(x, axis=(2,)) + assert "axis=" in y.astext() yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.TensorType( (1, 4), "float32") @@ -64,7 +64,7 @@ def test_squeeze_infer_type(): n, t, d = 1, 4, 1 x = relay.var("x", relay.TensorType((n, t, d), "float32")) y = relay.squeeze(x) - assert "axes=" not in y.astext() + assert "axis=" not in y.astext() yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.TensorType( (4,), "float32") @@ -74,7 +74,7 @@ def test_squeeze_infer_type(): def test_squeeze_bad_axes_infer_type(): n, t, d = 1, 4, 1 x = relay.var("x", relay.TensorType((n, t, d), "float32")) - y = relay.squeeze(x, axes=(1,)) + y = relay.squeeze(x, axis=(1,)) yy = relay.ir_pass.infer_type(y) diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py new file mode 100644 index 0000000000000..7ce3b35efe460 --- /dev/null +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -0,0 +1,153 @@ +from tvm import relay + + +def test_fold_fwd_simple(): + """Simple testcase.""" + def before(x, conv_weight, in_bias, in_scale, channels): + args = [x, conv_weight, in_bias, in_scale] + in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2) + in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2) + x = relay.multiply(x, in_scale) + x = relay.nn.relu(x) + x = relay.add(x, in_bias) + y = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + return relay.Function(args, y) + + def expected(x, conv_weight, in_bias, in_scale, channels): + # use a fixed order of args so alpha equal check can pass + args = [x, conv_weight, in_bias, in_scale] + in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2) + in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2) + squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) + x = relay.nn.relu(x) + in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + x = relay.add(x, in_bias) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + y = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + return relay.Function(args, y) + + def check(shape, channels): + x = relay.var("x", shape=shape) + in_channels = shape[1] + weight = relay.var("weight") + in_bias = relay.var("in_bias", shape=(in_channels,)) + in_scale = relay.var("in_scale", shape=(in_channels,)) + + y1 = before(x, weight, in_bias, in_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + type_dict = {x.name_hint:x.checked_type for x in y1.params} + weight = relay.var("weight", type_dict["weight"]) + y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) + y1_expected = expected(x, weight, in_bias, in_scale, channels) + assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + + check((2, 4, 10, 10), 2) + + +def test_fold_fwd_dual_path(): + """scale axis being consumed by two consumers""" + def before(x, conv_weight, in_bias, in_scale, channels): + args = [x, conv_weight, in_bias, in_scale] + x = relay.multiply(in_scale, x) + x = relay.nn.relu(x) + x = relay.subtract(x, in_bias) + y1 = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + data_layout="NHWC", + weight_layout="HWOI", + groups=channels, + padding=(1, 1)) + y2 = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + data_layout="NHWC", + weight_layout="HWOI", + groups=channels, + padding=(1, 1)) + z = relay.add(y1, y2) + return relay.Function(args, z) + + def expected(x, conv_weight, in_bias, in_scale, channels): + args = [x, conv_weight, in_bias, in_scale] + x = relay.nn.relu(x) + in_bias = relay.divide(in_bias, in_scale) + x = relay.subtract(x, in_bias) + y1 = relay.nn.conv2d(x, + relay.multiply(conv_weight, in_scale), + channels=channels, + kernel_size=(3, 3), + data_layout="NHWC", + weight_layout="HWOI", + groups=channels, + padding=(1, 1)) + y2 = relay.nn.conv2d(x, + relay.multiply(conv_weight, in_scale), + channels=channels, + kernel_size=(3, 3), + data_layout="NHWC", + weight_layout="HWOI", + groups=channels, + padding=(1, 1)) + z = relay.add(y1, y2) + return relay.Function(args, z) + + def check(shape, channels): + x = relay.var("x", shape=shape) + in_channels = shape[-1] + # test depthwise + assert in_channels == channels + weight = relay.var("weight") + in_bias = relay.var("in_bias", shape=(in_channels,)) + in_scale = relay.var("in_scale", shape=(in_channels,)) + y1 = before(x, weight, in_bias, in_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) + type_dict = {x.name_hint:x.checked_type for x in y1.params} + weight = relay.var("weight", type_dict["weight"]) + y1_expected = expected(x, weight, in_bias, in_scale, channels) + assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + + check((2, 4, 10, 3), 3) + + +def test_fold_fwd_fail(): + """testcase where we canont fold""" + def before(x, conv_weight, in_bias, in_scale, channels): + x = relay.multiply(x, in_scale) + xx = relay.nn.leaky_relu(x, alpha=0.1) + y1 = relay.nn.conv2d(xx, conv_weight, + channels=channels, + kernel_size=(3, 3), + data_layout="NHWC", + padding=(1, 1)) + z = relay.add(y1, x) + return relay.Function(relay.ir_pass.free_vars(z), z) + + def check(shape, channels): + x = relay.var("x", shape=shape) + in_channels = shape[-1] + # test depthwise + assert in_channels == channels + weight = relay.var("weight") + in_bias = relay.var("in_bias", shape=(in_channels,)) + in_scale = relay.var("in_scale", shape=(in_channels,)) + y1 = before(x, weight, in_bias, in_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) + assert relay.ir_pass.alpha_equal(y1, y1_folded) + + check((2, 11, 10, 4), 4) + + +if __name__ == "__main__": + test_fold_fwd_simple() + test_fold_fwd_dual_path() + test_fold_fwd_fail()