Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][[PASS] ForwardRewrite pass. #2124

Merged
merged 1 commit into from
Nov 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode {

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*!
* \brief Base class of the temporary expression.
*
* TempExprs are pass specific expression that can be
* useful to define intermediate result in the
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
class TempExprNode : public ExprNode {
public:
/*!
* \brief Convert the expression to a normal(non-temp) Expr.
* \return The corresponding normal(non-temp) expression.
*/
virtual Expr Realize() const = 0;

static constexpr const char* _type_key = "relay.TempExpr";
TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);

// implementataions
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
Expand Down
40 changes: 40 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,16 @@ class GenericOpMap {
*/
template <typename ValueType>
inline ValueType get(const Op& op, ValueType def_value) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
* \tparam ValueType The content value type.
*/
template <typename ValueType>
inline ValueType get(const Expr& expr, ValueType def_value) const;

private:
friend class OpRegistry;
Expand Down Expand Up @@ -313,6 +323,14 @@ class OpMap {
* \return the const reference to the content value.
*/
inline ValueType get(const Op& op, ValueType def_value) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
*/
inline ValueType get(const Expr& expr, ValueType def_value) const;

private:
friend class Op;
Expand Down Expand Up @@ -496,6 +514,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
}
}

template <typename ValueType>
inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const {
CHECK(expr.defined());
if (const OpNode* op = expr.as<OpNode>()) {
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second != 0) {
return data_[idx].first;
} else {
return value;
}
} else {
return value;
}
}

template <typename ValueType>
inline int OpMap<ValueType>::count(const Op& op) const {
return map_.count(op);
Expand All @@ -505,12 +538,19 @@ template <typename ValueType>
inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
return map_[op];
}

template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Op& op,
ValueType def_value) const {
return map_.get<ValueType>(op, def_value);
}

template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Expr& expr,
ValueType def_value) const {
return map_.get<ValueType>(expr, def_value);
}

/*!
* \brief Check that an expression is a "primtive operator".
*
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc<
Schedule(const Attrs& attrs,
const Array<Tensor>& outs,
const Target& target)>;

/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
* We can make use of the op and type information.
* \param new_args The new arguments (some of them could be TempExpr).
* \param ctx Optional context information about ref_call.
* \return The rewriten result call, can also return nullptr,
* which indicate the rewriter should use the default fallback
* rule that realizes all its input and compose the call.
*
* \note When we register the function, we can register
* a different signature with ctx to be a specific node type.
*/
using FForwardRewrite = runtime::TypedPackedFunc<
Expr(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx)>;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
11 changes: 11 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr);
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \return The rewritten expression.
*/
Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class PackedFunc {
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*! \brief constructor from null */
PackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
Expand Down
Loading