Skip to content

Commit

Permalink
[RELAY][PASS] FoldScaleAxis Forward (apache#2020)
Browse files Browse the repository at this point in the history
* [RELAY][PASS] FoldScaleAxis Forward

* Introduce helper function type_as

* Update per review comment

* Fix according to comments
  • Loading branch information
tqchen authored and AWS Neo committed Feb 20, 2019
1 parent 89601c6 commit ee7238e
Show file tree
Hide file tree
Showing 14 changed files with 1,001 additions and 67 deletions.
13 changes: 7 additions & 6 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,16 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {

/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
Array<IndexExpr> axes;
// use axis to make the name numpy compatible.
Array<Integer> axis;

TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
TVM_ATTR_FIELD(axes)
.describe("The axes to squeeze in the input tensor."
"If `axes = []`, all axis of dimension 1 get squeezed;"
TVM_ATTR_FIELD(axis)
.describe("The axis to squeeze in the input tensor."
"If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axes does not has dimension 1.")
.set_default(Array<IndexExpr>({}));
"It is an error if an axis does not has dimension 1.")
.set_default(NullValue<Array<Integer> >());
}
}; // struct SqueezeAttrs

Expand Down
26 changes: 26 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ class ExprNode : public RelayNode {
"field for this node";
return this->checked_type_;
}
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
*
* \note This function will thrown an error if the node type
* of this Expr is not TTypeNode.
*
* \return The corresponding TTypeNode pointer.
* \tparam The specific TypeNode we look for.
*/
template<typename TTypeNode>
inline const TTypeNode* type_as() const;

static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
Expand Down Expand Up @@ -391,6 +403,20 @@ class TupleGetItemNode : public ExprNode {

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

// implementataions
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
"TType must be a special case of type");
CHECK(checked_type_.defined())
<< "Type inference for this Expr has not completed";
const TTypeNode* node = checked_type_.as<TTypeNode>();
CHECK(node != nullptr)
<< "Expected type to be " << TTypeNode::_type_key
<< ", but get " << checked_type_->type_key();
return node;
}

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
14 changes: 11 additions & 3 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,14 @@ class ExprVisitor
class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
Expr Mutate(const Expr& expr);
/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
}
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override;
Expand All @@ -161,15 +168,16 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions.
/*!
* \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);

private:
protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};
Expand Down
43 changes: 27 additions & 16 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,42 @@ class OpNode : public relay::ExprNode {
v->Visit("support_level", &support_level);
}

/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
bool IsPrimitiveOp() const {
if (is_primitive_ != -1) return is_primitive_ != 0;
is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0;
return is_primitive_ != 0;
}

static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);

private:
// friend class
friend class GenericOpMap;
friend class OpRegistry;
friend bool IsPrimitiveOp(const Expr&);
// Program internal unique index of operator.
// Used to help index the program.
uint32_t index_{0};
// whether this is a primitive op. -1 means unknown.
mutable int is_primitive_{-1};
// Internal function to compute if it is primitive op
bool IsPrimitiveOp_() const {
const auto& fn_ty = this->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
}
};

/*!
Expand Down Expand Up @@ -497,22 +523,7 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
*/
inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>();

if (!op) {
return false;
}

const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;

const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}

return true;
return op != nullptr && op->IsPrimitiveOp();
}

} // namespace relay
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .expr import Expr
from .ty import Type


def infer_type(expr, env=None):
"""Infer the type of expr under the context of env.
Expand All @@ -30,6 +31,23 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env)


def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
Returns
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
"""
return _ir_pass.forward_fold_scale_axis(expr)


def well_formed(expr):
"""Check that each Var is only bound once (well formed).
Expand Down Expand Up @@ -149,6 +167,7 @@ def alpha_equal(lhs, rhs):
"""
return bool(_make._alpha_equal(lhs, rhs))


def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
Expand All @@ -170,6 +189,7 @@ def graph_equal(lhs, rhs):
"""
return bool(_make._graph_equal(lhs, rhs))


def structural_hash(value):
"""Hash a Relay expression structurally.
Expand Down
18 changes: 8 additions & 10 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,25 @@ def transpose(data, axes=None):
return _make.transpose(data, list(axes))


def squeeze(data, axes=None):
def squeeze(data, axis=None):
"""Squeeze axes in the array.
Parameters
----------
data : relay.Expr
data : tvm.relay.Expr
The input data to the operator.
axes : None or List[int]
Axes to remove.
If axes = [] or = None, remove all axis of dimensions 1.
Otherwise, remove all axis in axes.
If any axis in axes has dimension that does not equal 1, it is an error.
axis : None or List[int]
The set of axes to remove.
If axis = None, remove all axis of dimensions 1.
If any specified axis has dimension that does not equal 1, it is an error.
Returns
-------
result : relay.Expr
result : tvm.relay.Expr
The squeezed result.
"""
axes = axes or []
return _make.squeeze(data, list(axes))
return _make.squeeze(data, axis)


def reshape(data, newshape):
Expand Down
20 changes: 15 additions & 5 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,23 @@ class AlphaEqualHandler:
if (const CallNode* rhs = other.as<CallNode>()) {
if (!ExprEqual(lhs->op, rhs->op)) return false;
if (lhs->args.size() != rhs->args.size()) return false;
if (lhs->type_args.size() != rhs->type_args.size()) return false;

// skip type_args check for primitive ops.
bool is_primitive = IsPrimitiveOp(lhs->op);
if (!is_primitive) {
if (lhs->type_args.size() != rhs->type_args.size()) {
return false;
}
}
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
if (!ExprEqual(lhs->args[i], rhs->args[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;

if (!is_primitive) {
for (size_t i = 0; i < lhs->type_args.size(); ++i) {
if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
}
}
return AttrEqual(lhs->attrs, rhs->attrs);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
namespace tvm {
namespace relay {

Expr ExprMutator::Mutate(const Expr& expr) {
Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
return it->second;
} else {
Expr new_expr = ExprMutator::VisitExpr(expr);
Expr new_expr = ExprFunctor::VisitExpr(expr);
memo_[expr] = new_expr;
return new_expr;
}
Expand Down
14 changes: 6 additions & 8 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,9 +761,9 @@ Examples::
TVM_REGISTER_NODE_TYPE(SqueezeAttrs);

Expr MakeSqueeze(Expr data,
Array<IndexExpr> axes) {
Array<Integer> axis) {
auto attrs = make_node<SqueezeAttrs>();
attrs->axes = std::move(axes);
attrs->axis = std::move(axis);
static const Op& op = Op::Get("squeeze");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
Expand All @@ -785,8 +785,8 @@ bool SqueezeRel(const Array<Type>& types,
const auto* param = attrs.as<SqueezeAttrs>();
CHECK(param != nullptr);
std::vector<IndexExpr> result_shape;
// if axes is empty, squeeze all axes of dimension 1
if (param->axes.size() == 0) {
// if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) {
for (const auto& e : data->shape) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
Expand All @@ -800,10 +800,8 @@ bool SqueezeRel(const Array<Type>& types,
for (const auto& e : data->shape) {
original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
}
for (const auto& e : param->axes) {
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr);
original_shape.at(*axis_ptr).second = false;
for (const auto& e : param->axis) {
original_shape.at(e->value).second = false;
}
for (const auto p : original_shape) {
if (p.second) {
Expand Down
Loading

0 comments on commit ee7238e

Please sign in to comment.