revert gradient pass
zhiics committed Jun 28, 2019
1 parent e165fd8 commit 28f424f
Showing 7 changed files with 91 additions and 128 deletions.
7 changes: 0 additions & 7 deletions include/tvm/relay/transform.h
@@ -554,13 +554,6 @@ TVM_DLL Pass CanonicalizeCast();
*/
TVM_DLL Pass EtaExpand();

/*!
* \brief Compute the automatic differentiation of the Relay IR.
*
* \return The pass.
*/
TVM_DLL Pass Gradient();

/*!
* \brief This is a helper function that runs some optimization passes on
* a certain expression and returns the optimized version. With the help of this
31 changes: 31 additions & 0 deletions python/tvm/relay/ir_pass.py
@@ -497,6 +497,37 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)


def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
returning a function that calculates the original result,
paired with the gradients of the inputs.
Parameters
----------
expr : tvm.relay.Expr
The input expression, which is a Function or a GlobalVar.
mod : Optional[tvm.relay.Module]
The global module; used to look up the definition when expr is a GlobalVar.
mode : Optional[String]
The mode of the automatic differentiation algorithm.
'first_order' only works on first-order code, but will not produce references or closures.
'higher_order' works on all code, using references and closures.
Returns
-------
expr : tvm.relay.Expr
The transformed expression.
"""
if mode == 'first_order':
return _ir_pass.first_order_gradient(expr, mod)
elif mode == 'higher_order':
return _ir_pass.gradient(expr, mod)
else:
raise Exception('unknown mode')


def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
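
For readers following the revert, the re-added ir_pass.gradient is called directly on an expression and type-checked afterwards, which is exactly the pattern the updated tests later in this commit use. A minimal usage sketch based on that pattern (the shape, dtype, and the doubling function are illustrative, not taken from the commit):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import gradient, infer_type

# f(x) = x + x, so df/dx = 2 everywhere.
t = relay.TensorType((10, 4), "float32")
x = relay.var("x", t)
fwd_func = relay.Function([x], x + x)

# The transformed function returns (forward result, (one gradient per input,)).
bwd_func = infer_type(gradient(fwd_func, mode='higher_order'))

ex = relay.create_executor()
data = np.random.rand(10, 4).astype("float32")
forward, (grad,) = ex.evaluate(bwd_func)(data)
tvm.testing.assert_allclose(forward.asnumpy(), 2 * data)
tvm.testing.assert_allclose(grad.asnumpy(), 2 * np.ones_like(data))
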
24 changes: 0 additions & 24 deletions python/tvm/relay/transform.py
@@ -473,30 +473,6 @@ def CanonicalizeCast():
return _transform.CanonicalizeCast()


def Gradient(mode='higher_order'):
"""
Compute the gradient of the expressions in an input module.
Parameters
----------
mode: Optional[String]
The mode of the automatic differentiation algorithm.
'first_order' indicates the computation of the first order gradient,
which does not produce reference or closure. 'higher_order' can work on
all Relay expressions including those with reference and closure.
Returns
-------
ret: tvm.relay.Pass
The registered pass that computes the gradient.
"""
if mode == 'first_order':
return _transform.FirstOrderGradient()
if mode == 'higher_order':
return _transform.Gradient()
raise TypeError("Unknown mode: {}".format(mode))


def OptimizeOnExpr(expr, passes):
"""Perform optimization passes on an expression.
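
For contrast, the pass-style entry point deleted above was driven through the pass infrastructure. The test helper removed later in this commit wrapped it as follows; it is reproduced here only to document the pre-revert usage being reverted:

from tvm.relay import transform

# Pre-revert usage (removed by this commit): run Gradient as a pass over a
# single expression, then re-run type inference.
def run_gradient_pass(func, mode='higher_order'):
    return transform.OptimizeOnExpr(func, [transform.Gradient(mode),
                                           transform.InferType()])
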
65 changes: 28 additions & 37 deletions src/relay/pass/gradient.cc
@@ -61,6 +61,13 @@ using namespace tvm::runtime;
* There are multiple implementations of AD in Relay, with different characteristics.
* However, they all transform the input expr according to WithGradientType.
*/
Type WithGradientType(const Type&);

/*! Return an expression that represents the differentiation of e (according to WithGradientType).
* This version only works on first-order code without control flow.
*/
Expr FirstOrderGradient(const Expr& e, const Module& mod);

Type WithGradientType(const Type& t) {
// TODO(M.K.): stricter checking
auto ty = t.as<FuncTypeNode>();
@@ -71,6 +78,15 @@ Type WithGradientType(const Type& t) {
TupleTypeNode::make(ty->arg_types)}), {}, {});
}

//! \brief If the expression is a GlobalVar, replace it with the expression it refers to.
Expr DeGlobal(const Module& mod, const Expr& e) {
if (const auto* x = e.as<GlobalVarNode>()) {
return mod->Lookup(GetRef<GlobalVar>(x))->body;
} else {
return e;
}
}

/*! \brief A fragment of the program being built by the automatic differentiation
* pass.
*/
@@ -193,22 +209,18 @@ Type GradRetType(const Function& f) {
return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}

/*!
* \brief Return an expression that represents differentiation of an
* input expression (according to WithGradientType).
* This version only works on first order code without control flow.
*/
Expr FirstOrderGradient(const Expr& re) {
Expr FirstOrderGradient(const Expr& re, const Module& mod) {
// Currently we first remove any global functions for the first
// order case.
auto f = re.as<FunctionNode>();
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "FOWithGradient expects its argument to be a function: " << f;
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";

// We will then build a sequence of lets which implement reverse mode.
Expr body = LetList::With([&](LetList* ll) {
FirstOrderReverseAD reverse_ad(ll);
ADValue rev = reverse_ad(re);
ADValue rev = reverse_ad(e);
std::vector<ADValue> args;
for (const auto& p : f->params) {
args.push_back(std::make_shared<ADTensor>(ll, p));
@@ -234,6 +246,9 @@ Expr FirstOrderGradient(const Expr& re) {
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body_typed(FirstOrderGradient);

struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
@@ -306,13 +321,14 @@ Expr BPEmpty() {
return RefCreateNode::make(unitF);
}

Expr Gradient(const Expr& re) {
auto f = re.as<FunctionNode>();
Expr Gradient(const Expr& re, const Module& mod) {
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input needs to be a function";
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp)(re);
Expr rev = ReverseAD(bp)(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
@@ -329,33 +345,8 @@ Expr Gradient(const Expr& re) {
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

namespace transform {

Pass Gradient() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(Gradient(f));
};
return CreateFunctionPass(pass_func, 3, "Gradient",
{ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.Gradient")
TVM_REGISTER_API("relay._ir_pass.gradient")
.set_body_typed(Gradient);

Pass FirstOrderGradient() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FirstOrderGradient(f));
};
return CreateFunctionPass(pass_func, 3, "FirstOrderGradient",
{ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.FirstOrderGradient")
.set_body_typed(FirstOrderGradient);

} // namespace transform

} // namespace relay
} // namespace tvm
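
WithGradientType and GradRetType above fix the contract of the transformed function: an input of type (T1, ..., Tn) -> T becomes (T1, ..., Tn) -> (T, (T1, ..., Tn)), i.e. the original result paired with one gradient per parameter. A small sketch of that contract from the Python side, mirroring the checked_type assertions in the tests below (the two-argument multiply is illustrative only):

from tvm import relay
from tvm.relay.ir_pass import gradient, infer_type

t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
y = relay.var("y", t)
func = relay.Function([x, y], x * y)   # (t, t) -> t

back_func = infer_type(gradient(func))
# (t, t) -> (t, (t, t)): the forward value plus a gradient for each parameter.
assert back_func.checked_type == relay.FuncType(
    [t, t], relay.TupleType([t, relay.TupleType([t, t])]))
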
12 changes: 3 additions & 9 deletions tests/python/relay/test_op_grad_level1.py
@@ -17,24 +17,18 @@
import tvm
import numpy as np
from tvm import relay
from tvm.relay import transform
from tvm.relay.ir_pass import gradient, infer_type
from tvm.relay.testing import ctx_list

def run_gradient_pass(func, mode='higher_order'):
return transform.OptimizeOnExpr(func, [transform.Gradient(mode),
transform.InferType()])

def sigmoid(x):
one = np.ones_like(x)
return one / (one + np.exp(-x))


def relu(x):
x_copy = np.copy(x)
np.maximum(x_copy, 0, x_copy)
return x_copy


def test_unary_op():
def check_single_op(opfunc, ref):
shape = (10, 4)
@@ -47,7 +41,7 @@ def check_single_op(opfunc, ref):
data = np.random.rand(*shape).astype(dtype)
ref_grad = ref(data)
fwd_func = relay.Function([x], y)
bwd_func = run_gradient_pass(fwd_func)
bwd_func = infer_type(gradient(fwd_func))

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
@@ -79,7 +73,7 @@ def check_binary_op(opfunc, ref):
y_data = np.random.rand(*s).astype(t.dtype)
ref_grad0, ref_grad1 = ref(x_data, y_data)
fwd_func = relay.Function([x, y], z)
bwd_func = run_gradient_pass(fwd_func)
bwd_func = infer_type(gradient(fwd_func))

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
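
The diff above is cut off before the evaluation step of check_single_op. Purely as an illustration (the comparison below is inferred from the visible setup, not copied from the hidden lines), a gradient such as sigmoid's can be checked against its analytic derivative like this:

import numpy as np
from tvm import relay
from tvm.relay.ir_pass import gradient, infer_type

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

shape, dtype = (10, 4), 'float32'
x = relay.var("x", relay.TensorType(shape, dtype))
fwd_func = relay.Function([x], relay.sigmoid(x))
bwd_func = infer_type(gradient(fwd_func))

data = np.random.rand(*shape).astype(dtype)
ref_grad = sigmoid(data) * (1 - sigmoid(data))   # analytic d(sigmoid)/dx

ex = relay.create_executor()
op_res, (op_grad,) = ex.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_res.asnumpy(), sigmoid(data), rtol=0.01)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
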
42 changes: 15 additions & 27 deletions tests/python/relay/test_pass_gradient.py
@@ -14,24 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
from nose.tools import nottest

import tvm
from tvm import relay
from tvm.relay import create_executor, transform
from tvm.relay.ir_pass import free_vars, free_type_vars, gradient
from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr


def run_gradient_pass(func, mod=None, mode='higher_order'):
if mod:
mod[mod.entry_func] = func
mod = transform.Gradient(mode)(mod)
return mod[mod.entry_func]
else:
return transform.OptimizeOnExpr(func, [transform.Gradient(mode),
transform.InferType()])
import numpy as np


def rand(dtype='float32', *shape):
@@ -44,7 +34,7 @@ def test_id():
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x)
back_func = run_gradient_pass(func)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
@@ -59,7 +49,7 @@ def test_add():
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x + x)
back_func = run_gradient_pass(func)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
@@ -75,7 +65,7 @@ def test_temp_add():
x = relay.var("x", t)
y = x + x
func = relay.Function([x], y + y)
back_func = run_gradient_pass(func)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
@@ -90,7 +80,7 @@ def test_sub():
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x - x)
back_func = run_gradient_pass(func)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
@@ -113,7 +103,7 @@ def test_broadcast_add():
x = relay.var("x", t1)
y = relay.var("y", t2)
func = relay.Function([x, y], x + y)
full_func = run_gradient_pass(func)
full_func = relay.ir_pass.infer_type(gradient(func))
assert full_func.checked_type == relay.FuncType([t1, t2],
relay.TupleType([relay.TensorType(expected_forward.shape, dtype),
relay.TupleType([t1, t2])]))
@@ -140,7 +130,7 @@ def test_broadcast_subtract():
x = relay.var("x", t1)
y = relay.var("y", t2)
func = relay.Function([x, y], x - y)
full_func = run_gradient_pass(func)
full_func = relay.ir_pass.infer_type(gradient(func))
assert full_func.checked_type == relay.FuncType([t1, t2],
relay.TupleType([relay.TensorType(expected_forward.shape, dtype),
relay.TupleType([t1, t2])]))
@@ -165,7 +155,7 @@ def test_tuple():
relay.TupleGetItem(tup, 0) +
relay.TupleGetItem(tup, 1) -
relay.TupleGetItem(tup, 2)))
back_func = run_gradient_pass(func)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
x_nd = rand(dtype, *shape)
y_nd = rand(dtype, *shape)
@@ -182,10 +172,7 @@ def test_tuple():
tvm.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))


@nottest
def test_pow():
# This pass is disabled for now since the gradient pass does not really
# support polymorphism yet.
mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)
@@ -196,7 +183,7 @@ def test_pow():
double = relay.Function([x], x + x)
i = relay.var("i", t)
func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
back_func = run_gradient_pass(func, mod=mod)
back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod)
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
i_nd = rand(dtype, *shape)
ex = create_executor(mod=mod)
@@ -216,7 +203,7 @@ def test_ref():
body = relay.Let(u, relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)), body)
body = relay.Let(r, relay.RefCreate(x), body)
func = relay.Function([x], body)
back_func = run_gradient_pass(func)
back_func = relay.ir_pass.infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
x_nd = rand(dtype, *shape)
ex = create_executor()
@@ -231,10 +218,11 @@ def test_square_second_order():
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x * x)
back_func = run_gradient_pass(func)
back_func = relay.ir_pass.infer_type(gradient(func))
y = relay.var("y", t)
back_func_adjusted = relay.Function([y], relay.TupleGetItem(relay.TupleGetItem(back_func(y), 1), 0))
back_back_func = run_gradient_pass(back_func_adjusted)
back_func_adjusted = relay.ir_pass.infer_type(back_func_adjusted)
back_back_func = relay.ir_pass.infer_type(gradient(back_func_adjusted))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
x_nd = rand(dtype, *shape)
ex = create_executor()