Make first order gradient graphs more efficient
Previously, each node was visited as often as it was used, and a derivative was
computed on every visit; gradient contributions were only summed up at the
leaves. This patch instead sums the contributions at any node that is used
several times, so each node's subgraph is differentiated only once.
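
For intuition, here is a minimal Python sketch of the accumulate-at-shared-nodes idea (toy code under simplified assumptions, not the TVM implementation): every node carries a gradient accumulator, contributions from all of its uses are summed there, and each node is then processed exactly once in reverse topological order.

class Node:
    """Toy reverse-mode AD value; parents holds (input_node, local_grad) pairs."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents
        self.grad = 0.0  # contributions from every use accumulate here

def mul(a, b):
    return Node(a.value * b.value, [(a, b.value), (b, a.value)])

def add(a, b):
    return Node(a.value + b.value, [(a, 1.0), (b, 1.0)])

def backprop(out):
    order, seen = [], set()
    def topo(n):  # record each node once, producers before consumers
        if id(n) not in seen:
            seen.add(id(n))
            for p, _ in n.parents:
                topo(p)
            order.append(n)
    topo(out)
    out.grad = 1.0
    for n in reversed(order):         # one visit per node, even when shared
        for p, local in n.parents:
            p.grad += local * n.grad  # sum contributions at shared nodes

x = Node(3.0)
z = add(mul(x, x), x)  # x is used three times in the graph
backprop(z)
print(x.grad)  # d(x*x + x)/dx = 2*x + 1 = 7.0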
t-vi committed Jun 30, 2020
1 parent 2e04393 commit 35d00aa
Showing 3 changed files with 47 additions and 19 deletions.
16 changes: 16 additions & 0 deletions python/tvm/relay/testing/__init__.py
@@ -17,6 +17,7 @@
 #pylint: disable=invalid-name
 """Utilities for testing and benchmarks"""
 from __future__ import absolute_import as _abs
+import collections
 import numpy as np
 
 import tvm
@@ -135,3 +136,18 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
 
 def rand(dtype, *shape):
     return tvm.nd.array(np.random.rand(*shape).astype(dtype))
+
+
+def count_ops(expr):
+    """count number of times each op is called in the graph"""
+    class OpCounter(tvm.relay.ExprVisitor):
+        def visit_call(self, call):
+            if hasattr(call, 'op'):
+                self.node_counter[call.op.name] += 1
+            return super().visit_call(call)
+        def count(self, expr):
+            self.node_set = {}
+            self.node_counter = collections.Counter()
+            self.visit(expr)
+            return self.node_counter
+    return OpCounter().count(expr)
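
For illustration, the new helper can be exercised like this (the tiny relay program below is an assumption made for the example, not part of the commit):

import tvm
from tvm import relay
from tvm.relay.testing import count_ops

x = relay.var("x", shape=(4, 4))
f = relay.Function([x], relay.nn.relu(x + x))
print(count_ops(f))  # expected: Counter({'add': 1, 'nn.relu': 1})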
18 changes: 13 additions & 5 deletions src/relay/transforms/gradient.cc
@@ -162,14 +162,24 @@ struct ADFunction : ADValueNode {
 };
 
 struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
+  using TBase = ExprFunctor<ADValue(const Expr&)>;
   const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
   std::vector<std::function<void(LetList* ll)>> backprop_actions;
   // we assume no closure so no need for lexical scoping
-  std::unordered_map<Var, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
+  std::unordered_map<Expr, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
   LetList* ll;
 
   FirstOrderReverseAD(LetList* ll) : ll(ll) {}
 
+  ADValue VisitExpr(const Expr& n) final {
+    if (env.count(n)) {
+      return env.at(n);
+    }
+    auto ret = TBase::VisitExpr(n);
+    env[n] = ret;
+    return ret;
+  }
+
   ADValue VisitExpr_(const OpNode* op) final {
     Op op_ref = GetRef<Op>(op);
     CHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined";
@@ -268,10 +278,8 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
     });
   }
 
-  ADValue VisitExpr_(const VarNode* op) final {
-    Var v = GetRef<Var>(op);
-    return env.at(v);
-  }
+  // Var will always be in env, handled in VisitExpr (without _), so we don't need
+  // to implement its VisitExpr_.
 };
 
 Type GradRetType(const Function& f) {
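The load-bearing change above is the VisitExpr override: all dispatch now funnels through a single memoizing entry point, so an expression referenced several times is turned into an ADValue only once (which is also why the VarNode overload could be dropped). A rough Python analogue of the pattern, with hypothetical names:

class MemoizingTranslator:
    """Toy analogue of the memoizing VisitExpr above; not TVM code."""
    def __init__(self):
        self.env = {}  # id(node) -> cached result, like env in the C++

    def visit(self, node):
        key = id(node)
        if key in self.env:         # visited before: reuse the cached value
            return self.env[key]
        ret = self.translate(node)  # first visit: do the per-node work
        self.env[key] = ret
        return ret

    def translate(self, node):
        # stand-in for the per-node-type VisitExpr_ overloads
        return ("ad_value_for", node)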
32 changes: 18 additions & 14 deletions tests/python/relay/test_pass_gradient.py
@@ -14,7 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import collections
 import numpy as np
+import pytest
 
 import tvm
 from tvm import te
@@ -23,7 +25,7 @@
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
+from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand, count_ops
 import tvm.relay.op as op


@@ -309,17 +311,19 @@ def test_concat():
     # no value validation as concatenate has dummy gradient right now.
 
 
+def test_no_duplication():
+    x = tvm.relay.Var('x', type_annotation=tvm.relay.TensorType([12, 12]))
+    y = tvm.relay.Var('y', type_annotation=tvm.relay.TensorType([12, 12]))
+    xy = tvm.relay.nn.dense(x, y)
+
+    m = tvm.relay.sum(xy, keepdims=True)
+    s = tvm.relay.sum(xy - m)
+    fn = tvm.relay.Function([x,y], s)
+    fn = run_infer_type(fn)
+    gr = tvm.relay.transform.gradient(fn, mode='first_order')
+
+    counts = count_ops(gr)
+    assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two backward)"
+
 if __name__ == "__main__":
-    test_id()
-    test_add()
-    test_temp_add()
-    test_sub()
-    test_broadcast_add()
-    test_broadcast_subtract()
-    test_tuple()
-    test_tuple_first_order()
-    test_pow()
-    test_ref()
-    test_square_second_order()
-    test_if()
-    test_grad_tuple()
+    pytest.main([__file__])
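
Why test_no_duplication expects exactly three nn.dense calls: the forward graph contains one, and the registered gradient of dense contributes one matmul per input; with the shared dense result now differentiated only once, nothing is duplicated. A numpy sketch of the two backward matmuls, assuming nn.dense(x, y) computes x @ y.T and a plain sum as the loss (the test's actual loss subtracts a mean, but the op count is the same):

import numpy as np

x, y = np.random.rand(12, 12), np.random.rand(12, 12)
g = np.ones((12, 12))  # upstream gradient of sum(x @ y.T)
dx = g @ y             # backward dense #1: gradient w.r.t. x
dy = g.T @ x           # backward dense #2: gradient w.r.t. y
# together with the single forward dense: three dense ops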
