Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Make check stricter by using Feature. Fixed multiple bugs. #6326

Merged
merged 8 commits into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/tvm/relay/feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/expr.h>

#include <bitset>
#include <string>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -124,6 +125,13 @@ class FeatureSet {
*/
bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); }

/*!
* \brief Pretty Print the FeatureSet.
*
* \return a string representation.
*/
std::string Print() const;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe ToString sounds better?


private:
std::bitset<feature_count> bs_;
FeatureSet() = default;
Expand Down Expand Up @@ -160,6 +168,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}

/*!
* \brief Check the feature of the program.
*
* \param expr The expression.
* \param fs The feature set of the program.
*/
void CheckFeature(const RelayExpr& expr, const FeatureSet& fs);

/*!
* \brief Check the feature of the program.
*
* \param mod The module.
* \param fs The feature set of the program.
*/
void CheckFeature(const IRModule& mod, const FeatureSet& fs);

/*!
* \brief Check the feature of the program.
*
* \param expr The expression.
* \param mod The module.
* \param fs The feature set of the program.
*/
inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) {
CheckFeature(expr, fs);
CheckFeature(mod, fs);
}

} // namespace relay
} // namespace tvm

Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ TVM_DLL Pass ToBasicBlockNormalForm();
*/
TVM_DLL Pass ToANormalForm();

/*!
* \brief ToANormalForm but on incomplete graph.
*
* \param expr the graph.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);

/*!
* \brief Turn an expression into continuation passing style(CPS).
*
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule, TypeCall
from tvm import relay

from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
Expand Down Expand Up @@ -1237,6 +1238,7 @@ def __init__(self, mod=None):
mod = IRModule()
self.mod = mod
self.load_prelude()
self.mod = relay.transform.ToANormalForm()(self.mod)

def get_name(self, canonical, dtype):
"""Get name corresponding to the canonical name"""
Expand Down
51 changes: 47 additions & 4 deletions src/relay/analysis/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,43 @@ FeatureSet DetectFeature(const Expr& expr) {
return fd.fs;
}

std::string FeatureSet::Print() const {
std::string ret;
ret += "[";
size_t detected = 0;
#define DETECT_FEATURE(FEATURE_NAME) \
++detected; \
if (bs_[FEATURE_NAME]) { \
ret += #FEATURE_NAME; \
ret += ", "; \
}
DETECT_FEATURE(fVar);
DETECT_FEATURE(fGlobalVar);
DETECT_FEATURE(fConstant);
DETECT_FEATURE(fTuple);
DETECT_FEATURE(fTupleGetItem);
DETECT_FEATURE(fFunction);
DETECT_FEATURE(fOp);
DETECT_FEATURE(fCall);
DETECT_FEATURE(fLet);
DETECT_FEATURE(fIf);
DETECT_FEATURE(fRefCreate);
DETECT_FEATURE(fRefRead);
DETECT_FEATURE(fRefWrite);
DETECT_FEATURE(fConstructor);
DETECT_FEATURE(fMatch);
DETECT_FEATURE(fGraph);
DETECT_FEATURE(fLetRec);
#undef DETECT_FEATURE
CHECK(detected == feature_count) << "some feature not printed";
ret += "]";
return ret;
}

FeatureSet DetectFeature(const IRModule& mod) {
FeatureSet fs = FeatureSet::No();
if (mod.defined()) {
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
return fs;
}
Expand All @@ -106,5 +137,17 @@ Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod)

TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);

void CheckFeature(const Expr& expr, const FeatureSet& fs) {
auto dfs = DetectFeature(expr);
CHECK(dfs.is_subset_of(fs)) << AsText(expr, false)
<< "\nhas unsupported feature: " << (dfs - fs).Print();
}

void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
for (const auto& f : mod->functions) {
CheckFeature(f.second, fs);
}
}

} // namespace relay
} // namespace tvm
55 changes: 42 additions & 13 deletions src/relay/transforms/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

Expand Down Expand Up @@ -81,7 +82,7 @@ Type WithGradientType(const Type& t) {
Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
const auto* x = e.as<GlobalVarNode>();

if (mod.defined() && (x)) {
if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
Expand Down Expand Up @@ -354,9 +355,9 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (forward_type.as<TensorTypeNode>()) {
auto ret = f(e);
auto ret = ll->Push(f(e));
ret->checked_type_ = tf(forward_type);
return ret;
return std::move(ret);
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
tvm::Array<Type> types;
Expand All @@ -365,7 +366,7 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
fields.push_back(field);
types.push_back(field->checked_type_);
}
auto ret = Tuple(fields);
auto ret = ll->Push(Tuple(fields));
ret->checked_type_ = TupleType(types);
return std::move(ret);
} else {
Expand Down Expand Up @@ -395,9 +396,10 @@ void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, L
}
}

// TODO(@M.K.): why take Expr?
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); };
auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); };
auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); };
return LiftTensor(rev, rev_type, forward_type, e, ll);
}
Expand All @@ -411,14 +413,14 @@ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {

/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); };
auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); };
auto grad_type = [&](const Type& forward_type) { return forward_type; };
return LiftTensor(grad, grad_type, forward_type, e, ll);
}

void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
if (t.as<TensorTypeNode>()) {
ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad)));
ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad)));
} else if (auto* tt = t.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll);
Expand Down Expand Up @@ -448,14 +450,32 @@ struct ReverseAD : ExprMutator {
throw;
}

Expr Remap(const Expr& e) {
struct Remapper : ExprMutator {
std::shared_ptr<ADVarMap> ad_vars;
LetList* ll;
Remapper(const std::shared_ptr<ADVarMap>& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {}
Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (ad_vars->count(var_ref) == 0) {
return std::move(var_ref);
} else {
return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll);
}
}
};
return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); });
}

Expr VisitCheckpoint(const CallNode* call) {
const OpNode* op_node = call->op.as<OpNode>();
CHECK(op_node) << "expected op in call";
Op op_ref = GetRef<Op>(op_node);
CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
auto x = call->args[0];
return LetList::With([&](LetList* ll) {
auto x_var = ll->Push(x);
auto x_var = ll->Push(Remap(x));
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
Expand Down Expand Up @@ -508,16 +528,19 @@ struct ReverseAD : ExprMutator {
return Call(bpv, {});
}),
TupleType::Empty(), {});
ll->Push(RefWrite(bp, nbp));
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
}
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return Pair(e, RefCreate(ZerosLike(e)));
return LetList::With([&](LetList* ll) {
Expr e = ll->Push(GetRef<Expr>(op));
return Pair(e, RefCreate(ZerosLike(e)));
});
}

Expr VisitExpr_(const IfNode* op) final {
Expand All @@ -528,7 +551,7 @@ struct ReverseAD : ExprMutator {
Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (!ad_vars->count(var_ref)) {
if (ad_vars->count(var_ref) == 0) {
auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
(*ad_vars)[var_ref] = res;
}
Expand Down Expand Up @@ -568,6 +591,10 @@ bool MissingGrad(const Expr& e) {
}

Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
CheckFeature(re, FeatureSet::All() - fGraph);
if (mod.defined()) {
CheckFeature(mod.value(), FeatureSet::All() - fGraph);
}
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
Expand Down Expand Up @@ -619,7 +646,9 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
};
return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
});
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
CheckFeature(ret, FeatureSet::All() - fGraph);
return std::move(ret);
}

TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
Expand Down
Loading