Skip to content

Commit

Permalink
[RELAY][[PASS] Consolidate ForwardRewrite pass. (#2124)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Nov 16, 2018
1 parent de3b63a commit 59c70a0
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 178 deletions.
26 changes: 26 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode {

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*!
* \brief Base class of the temporary expression.
*
* TempExprs are pass specific expression that can be
* useful to define intermediate result in the
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
class TempExprNode : public ExprNode {
public:
/*!
* \brief Convert the expression to a normal(non-temp) Expr.
* \return The corresponding normal(non-temp) expression.
*/
virtual Expr Realize() const = 0;

static constexpr const char* _type_key = "relay.TempExpr";
TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);

// implementataions
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
Expand Down
40 changes: 40 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,16 @@ class GenericOpMap {
*/
template <typename ValueType>
inline ValueType get(const Op& op, ValueType def_value) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
* \tparam ValueType The content value type.
*/
template <typename ValueType>
inline ValueType get(const Expr& expr, ValueType def_value) const;

private:
friend class OpRegistry;
Expand Down Expand Up @@ -313,6 +323,14 @@ class OpMap {
* \return the const reference to the content value.
*/
inline ValueType get(const Op& op, ValueType def_value) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
*/
inline ValueType get(const Expr& expr, ValueType def_value) const;

private:
friend class Op;
Expand Down Expand Up @@ -496,6 +514,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
}
}

template <typename ValueType>
inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const {
CHECK(expr.defined());
if (const OpNode* op = expr.as<OpNode>()) {
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second != 0) {
return data_[idx].first;
} else {
return value;
}
} else {
return value;
}
}

template <typename ValueType>
inline int OpMap<ValueType>::count(const Op& op) const {
return map_.count(op);
Expand All @@ -505,12 +538,19 @@ template <typename ValueType>
inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
return map_[op];
}

template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Op& op,
ValueType def_value) const {
return map_.get<ValueType>(op, def_value);
}

template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Expr& expr,
ValueType def_value) const {
return map_.get<ValueType>(expr, def_value);
}

/*!
* \brief Check that an expression is a "primtive operator".
*
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Attrs& attrs,
const Array<Tensor>& outs,
const Target& target)>;

/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
* We can make use of the op and type information.
* \param new_args The new arguments (some of them could be TempExpr).
* \param ctx Optional context information about ref_call.
* \return The rewriten result call, can also return nullptr,
* which indicate the rewriter should use the default fallback
* rule that realizes all its input and compose the call.
*
* \note When we register the function, we can register
* a different signature with ctx to be a specific node type.
*/
using FForwardRewrite = runtime::TypedPackedFunc<
Expr(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx)>;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
11 changes: 11 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr);
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \return The rewritten expression.
*/
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class PackedFunc {
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*! \brief constructor from null */
PackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
Expand Down
Loading

0 comments on commit 59c70a0

Please sign in to comment.