Skip to content

Commit

Permalink
[Relay] Make check stricter by using Feature. Fixed multiple bugs. (a…
Browse files Browse the repository at this point in the history
…pache#6326)

* save

lint

lint

lint

fix lint

lint

update

lint

save

save

save

lint

format

format

save

save

fix

use a form more suitable for numeric check

save

* save

* save

* lint

* save

* lint

* fix

* fix
  • Loading branch information
MarisaKirisame authored and trevor-m committed Sep 3, 2020
1 parent 337179f commit 0fb51f3
Show file tree
Hide file tree
Showing 17 changed files with 291 additions and 170 deletions.
34 changes: 34 additions & 0 deletions include/tvm/relay/feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/expr.h>

#include <bitset>
#include <string>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -124,6 +125,11 @@ class FeatureSet {
*/
bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); }

/*!
* \brief return a string representation.
*/
std::string ToString() const;

private:
std::bitset<feature_count> bs_;
FeatureSet() = default;
Expand Down Expand Up @@ -160,6 +166,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}

/*!
* \brief Check the feature of the program.
*
* \param expr The expression.
* \param fs The feature set of the program.
*/
void CheckFeature(const RelayExpr& expr, const FeatureSet& fs);

/*!
* \brief Check the feature of the program.
*
* \param mod The module.
* \param fs The feature set of the program.
*/
void CheckFeature(const IRModule& mod, const FeatureSet& fs);

/*!
* \brief Check the feature of the program.
*
* \param expr The expression.
* \param mod The module.
* \param fs The feature set of the program.
*/
inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) {
CheckFeature(expr, fs);
CheckFeature(mod, fs);
}

} // namespace relay
} // namespace tvm

Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ TVM_DLL Pass ToBasicBlockNormalForm();
*/
TVM_DLL Pass ToANormalForm();

/*!
* \brief ToANormalForm but on incomplete graph.
*
* \param expr the graph.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);

/*!
* \brief Turn an expression into continuation passing style(CPS).
*
Expand Down
16 changes: 9 additions & 7 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule, TypeCall
from tvm.relay.transform import ToANormalFormExpr

from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
Expand Down Expand Up @@ -204,7 +205,6 @@ def define_tensor_concatenate(self):
self.prelude.mod[concat_var] = \
Function([x, y], Match(x, [case], False), tensor_type_var(), [])


def define_tensor_expand_dims(self):
"""Defines a function to grow a tensor_t's rank by adding one dimension in front
of the original tensor_t.
Expand Down Expand Up @@ -511,8 +511,9 @@ def define_tensor_array_stack(self):
self.prelude.hd(tensor_array_expand_dims),
self.prelude.tl(tensor_array_expand_dims))
output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
self.prelude.mod[stack_var] = Function([tensor_array], tensors,
output_tensor_type_var(), [])
self.prelude.mod[stack_var] = \
Function([tensor_array], tensors,
output_tensor_type_var(), [])

def define_tensor_array_gather(self):
"""Defines a function to return the selected values in a tensor array as tensor_t.
Expand Down Expand Up @@ -809,7 +810,7 @@ def define_tensor_concat(self):
tensor4_var(op.concatenate([t41, t42], axis=0)))],
False))
# op.concatenate does not support tensor with rank higher than 4
self.prelude.mod[concat_var] =\
self.prelude.mod[concat_var] = \
Function([x, y], Match(x, [tensor1_case,
tensor2_case,
tensor3_case,
Expand Down Expand Up @@ -1167,7 +1168,7 @@ def define_tensor_array_gather(self):
current = Var("current", scalar_type('int32'))
limit = Var("limit", scalar_type('int32'))
indices_ = Var('indices_', TensorType([Any()], 'int32'))
helper_body =\
helper_body = \
If(equal(current, const(0)),
stack_var(accu),
helper_var(
Expand All @@ -1187,7 +1188,7 @@ def define_tensor_array_gather(self):
indices_shape = op.shape_of(indices)
limit = op.take(indices_shape, const(0))
body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
self.prelude.mod[gather_var] =\
self.prelude.mod[gather_var] = \
Function([tensor_array, indices], body, tensor_type_var(), [])

def define_tensor_array_stack(self):
Expand All @@ -1205,7 +1206,8 @@ def define_tensor_array_stack(self):
tensors = self.prelude.foldl(concat_var,
self.prelude.hd(tensor_array_expand_dims),
self.prelude.tl(tensor_array_expand_dims))
self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), [])
self.prelude.mod[stack_var] = \
Function([tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), [])

def register(self):
"""Register all tensor array ops in Prelude"""
Expand Down
17 changes: 16 additions & 1 deletion python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,11 +510,26 @@ def ToANormalForm():
Returns
-------
ret: Union[tvm.transform.Pass, tvm.relay.Expr]
ret : Union[tvm.transform.Pass, tvm.relay.Expr]
The registered pass that transforms an expression into A Normal Form.
"""
return _ffi_api.ToANormalForm()

def ToANormalFormExpr(e):
"""ToANormalForm, but on expression level.
Parameters
----------
e : Expr
The graph expression.
Returns
-------
ret : Expr
The transformed expresion.
"""
return _ffi_api.ToANormalFormExpr(e)

def ToBasicBlockNormalForm():
"""Turn an expression to Basic Block Normal Form.
We define a block as a group of expressions implied by the scope structure.
Expand Down
51 changes: 47 additions & 4 deletions src/relay/analysis/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,43 @@ FeatureSet DetectFeature(const Expr& expr) {
return fd.fs;
}

std::string FeatureSet::ToString() const {
std::string ret;
ret += "[";
size_t detected = 0;
#define DETECT_FEATURE(FEATURE_NAME) \
++detected; \
if (bs_[FEATURE_NAME]) { \
ret += #FEATURE_NAME; \
ret += ", "; \
}
DETECT_FEATURE(fVar);
DETECT_FEATURE(fGlobalVar);
DETECT_FEATURE(fConstant);
DETECT_FEATURE(fTuple);
DETECT_FEATURE(fTupleGetItem);
DETECT_FEATURE(fFunction);
DETECT_FEATURE(fOp);
DETECT_FEATURE(fCall);
DETECT_FEATURE(fLet);
DETECT_FEATURE(fIf);
DETECT_FEATURE(fRefCreate);
DETECT_FEATURE(fRefRead);
DETECT_FEATURE(fRefWrite);
DETECT_FEATURE(fConstructor);
DETECT_FEATURE(fMatch);
DETECT_FEATURE(fGraph);
DETECT_FEATURE(fLetRec);
#undef DETECT_FEATURE
CHECK(detected == feature_count) << "some feature not printed";
ret += "]";
return ret;
}

FeatureSet DetectFeature(const IRModule& mod) {
FeatureSet fs = FeatureSet::No();
if (mod.defined()) {
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
return fs;
}
Expand All @@ -106,5 +137,17 @@ Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod)

TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);

void CheckFeature(const Expr& expr, const FeatureSet& fs) {
auto dfs = DetectFeature(expr);
CHECK(dfs.is_subset_of(fs)) << AsText(expr, false)
<< "\nhas unsupported feature: " << (dfs - fs).ToString();
}

void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
for (const auto& f : mod->functions) {
CheckFeature(f.second, fs);
}
}

} // namespace relay
} // namespace tvm
55 changes: 42 additions & 13 deletions src/relay/transforms/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

Expand Down Expand Up @@ -81,7 +82,7 @@ Type WithGradientType(const Type& t) {
Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
const auto* x = e.as<GlobalVarNode>();

if (mod.defined() && (x)) {
if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
Expand Down Expand Up @@ -354,9 +355,9 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (forward_type.as<TensorTypeNode>()) {
auto ret = f(e);
auto ret = ll->Push(f(e));
ret->checked_type_ = tf(forward_type);
return ret;
return std::move(ret);
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
tvm::Array<Type> types;
Expand All @@ -365,7 +366,7 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
fields.push_back(field);
types.push_back(field->checked_type_);
}
auto ret = Tuple(fields);
auto ret = ll->Push(Tuple(fields));
ret->checked_type_ = TupleType(types);
return std::move(ret);
} else {
Expand Down Expand Up @@ -395,9 +396,10 @@ void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, L
}
}

// TODO(@M.K.): why take Expr?
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); };
auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); };
auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); };
return LiftTensor(rev, rev_type, forward_type, e, ll);
}
Expand All @@ -411,14 +413,14 @@ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {

/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); };
auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); };
auto grad_type = [&](const Type& forward_type) { return forward_type; };
return LiftTensor(grad, grad_type, forward_type, e, ll);
}

void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
if (t.as<TensorTypeNode>()) {
ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad)));
ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad)));
} else if (auto* tt = t.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll);
Expand Down Expand Up @@ -448,14 +450,32 @@ struct ReverseAD : ExprMutator {
throw;
}

Expr Remap(const Expr& e) {
struct Remapper : ExprMutator {
std::shared_ptr<ADVarMap> ad_vars;
LetList* ll;
Remapper(const std::shared_ptr<ADVarMap>& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {}
Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (ad_vars->count(var_ref) == 0) {
return std::move(var_ref);
} else {
return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll);
}
}
};
return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); });
}

Expr VisitCheckpoint(const CallNode* call) {
const OpNode* op_node = call->op.as<OpNode>();
CHECK(op_node) << "expected op in call";
Op op_ref = GetRef<Op>(op_node);
CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
auto x = call->args[0];
return LetList::With([&](LetList* ll) {
auto x_var = ll->Push(x);
auto x_var = ll->Push(Remap(x));
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
Expand Down Expand Up @@ -508,16 +528,19 @@ struct ReverseAD : ExprMutator {
return Call(bpv, {});
}),
TupleType::Empty(), {});
ll->Push(RefWrite(bp, nbp));
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
}
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return Pair(e, RefCreate(ZerosLike(e)));
return LetList::With([&](LetList* ll) {
Expr e = ll->Push(GetRef<Expr>(op));
return Pair(e, RefCreate(ZerosLike(e)));
});
}

Expr VisitExpr_(const IfNode* op) final {
Expand All @@ -528,7 +551,7 @@ struct ReverseAD : ExprMutator {
Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (!ad_vars->count(var_ref)) {
if (ad_vars->count(var_ref) == 0) {
auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
(*ad_vars)[var_ref] = res;
}
Expand Down Expand Up @@ -568,6 +591,10 @@ bool MissingGrad(const Expr& e) {
}

Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
CheckFeature(re, FeatureSet::All() - fGraph);
if (mod.defined()) {
CheckFeature(mod.value(), FeatureSet::All() - fGraph);
}
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
Expand Down Expand Up @@ -619,7 +646,9 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
};
return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
});
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
CheckFeature(ret, FeatureSet::All() - fGraph);
return std::move(ret);
}

TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
Expand Down
Loading

0 comments on commit 0fb51f3

Please sign in to comment.