From 35d00aa739bc3c3d85210c7224cfe90b1eb7c2cc Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 30 Jun 2020 07:35:06 +0200
Subject: [PATCH] Make first order gradient graphs more efficient

Previously, each node was visited as often as it was used, and a
derivative was computed on every visit; the individual contributions
were only added up at the leaves. This patch memoizes visited nodes,
so gradient contributions are accumulated at any node that is used
several times and the subexpression below it is differentiated only
once.
---
 python/tvm/relay/testing/__init__.py     | 16 ++++++++++++
 src/relay/transforms/gradient.cc         | 18 ++++++++++++------
 tests/python/relay/test_pass_gradient.py | 32 +++++++++++++-----------
 3 files changed, 47 insertions(+), 19 deletions(-)

diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py
index 8310a0202c17..404bb8012023 100644
--- a/python/tvm/relay/testing/__init__.py
+++ b/python/tvm/relay/testing/__init__.py
@@ -17,6 +17,7 @@
 #pylint: disable=invalid-name
 """Utilities for testing and benchmarks"""
 from __future__ import absolute_import as _abs
+import collections
 
 import numpy as np
 import tvm
@@ -135,3 +136,18 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
 
 def rand(dtype, *shape):
     return tvm.nd.array(np.random.rand(*shape).astype(dtype))
+
+
+def count_ops(expr):
+    """count number of times a given op is called in the graph"""
+    class OpCounter(tvm.relay.ExprVisitor):
+        def visit_call(self, call):
+            if hasattr(call, 'op'):
+                self.node_counter[call.op.name] += 1
+            return super().visit_call(call)
+        def count(self, expr):
+            self.node_set = {}
+            self.node_counter = collections.Counter()
+            self.visit(expr)
+            return self.node_counter
+    return OpCounter().count(expr)
diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 6fee40c51337..0cebba72c375 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -162,14 +162,24 @@ struct ADFunction : ADValueNode {
 };
 
 struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
+  using TBase = ExprFunctor<ADValue(const Expr&)>;
   const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
   std::vector<std::function<void(LetList*)>> backprop_actions;
   // we assume no closure so no need for lexical scoping
-  std::unordered_map<Var, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
+  std::unordered_map<Expr, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
   LetList* ll;
 
   FirstOrderReverseAD(LetList* ll) : ll(ll) {}
 
+  ADValue VisitExpr(const Expr& n) final {
+    if (env.count(n)) {
+      return env.at(n);
+    }
+    auto ret = TBase::VisitExpr(n);
+    env[n] = ret;
+    return ret;
+  }
+
   ADValue VisitExpr_(const OpNode* op) final {
     Op op_ref = GetRef<Op>(op);
     CHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined";
@@ -268,10 +278,8 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
     });
   }
 
-  ADValue VisitExpr_(const VarNode* op) final {
-    Var v = GetRef<Var>(op);
-    return env.at(v);
-  }
+  // Var will always be in env, handled in VisitExpr (without _), so we don't need
+  // to implement its VisitExpr_.
 };
 
 Type GradRetType(const Function& f) {
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index e28eb4a6b249..4838c6a4e7fc 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -14,7 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import collections
 import numpy as np
+import pytest
 
 import tvm
 from tvm import te
@@ -23,7 +25,7 @@
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
-from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
+from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand, count_ops
 import tvm.relay.op as op
 
 
@@ -309,17 +311,19 @@ def test_concat():
     # no value validation as concatenate has dummy gradient right now.
 
 
+def test_no_duplication():
+    x = tvm.relay.Var('x', type_annotation=tvm.relay.TensorType([12, 12]))
+    y = tvm.relay.Var('y', type_annotation=tvm.relay.TensorType([12, 12]))
+    xy = tvm.relay.nn.dense(x, y)
+
+    m = tvm.relay.sum(xy, keepdims=True)
+    s = tvm.relay.sum(xy - m)
+    fn = tvm.relay.Function([x, y], s)
+    fn = run_infer_type(fn)
+    gr = tvm.relay.transform.gradient(fn, mode='first_order')
+
+    counts = count_ops(gr)
+    assert counts['nn.dense'] == 3, "We expect 3 dense (one forward, two backward)"
+
 if __name__ == "__main__":
-    test_id()
-    test_add()
-    test_temp_add()
-    test_sub()
-    test_broadcast_add()
-    test_broadcast_subtract()
-    test_tuple()
-    test_tuple_first_order()
-    test_pow()
-    test_ref()
-    test_square_second_order()
-    test_if()
-    test_grad_tuple()
+    pytest.main([__file__])
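
Reviewer note (not part of the patch): a minimal sketch of how to observe the
deduplication with the new count_ops helper, assuming this patch is applied.
The variable names mirror test_no_duplication:

    import tvm
    from tvm import relay
    from tvm.relay.testing import count_ops, run_infer_type

    # dense(x, y) is used twice (in m and in xy - m); before this patch the
    # first-order pass would differentiate it once per use.
    x = relay.Var('x', type_annotation=relay.TensorType([12, 12]))
    y = relay.Var('y', type_annotation=relay.TensorType([12, 12]))
    xy = relay.nn.dense(x, y)
    m = relay.sum(xy, keepdims=True)
    s = relay.sum(xy - m)
    fn = run_infer_type(relay.Function([x, y], s))
    gr = relay.transform.gradient(fn, mode='first_order')
    print(count_ops(gr))  # with this patch: 3 nn.dense (one forward, two backward)

Without the memoization, the dense node is differentiated once per use, so the
gradient graph carries additional nn.dense calls.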