Make first order gradient graphs more efficient
Previously, each node was visited as often as it was used, and a
derivative was computed on every visit; the contributions were only
summed up at the leaves. This patch instead sums the contributions
once at any node that is used more than once.
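For intuition, here is a minimal sketch of the accumulation scheme in plain Python, independent of TVM: gradient contributions are summed per node and each node is processed exactly once, instead of re-deriving a node for every one of its uses. The `Node` class, `backward` helper, and the trivial gradient rule below are illustrative assumptions only, not TVM's API.

```python
# Toy illustration of the accumulation scheme, not TVM code: gradients are
# summed per node, and each node is processed exactly once.
import collections

class Node:
    """A node in a tiny expression DAG; inputs may be shared between nodes."""
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)

def backward(root):
    # Reverse topological order: every node appears once, even if used often.
    order, seen = [], set()
    def topo(n):
        if id(n) in seen:
            return
        seen.add(id(n))
        for i in n.inputs:
            topo(i)
        order.append(n)
    topo(root)

    # Accumulate contributions at each node, then propagate the sum once.
    grads = collections.defaultdict(float)
    grads[id(root)] = 1.0
    for n in reversed(order):
        for i in n.inputs:
            # Toy gradient rule: each input just receives the upstream gradient.
            grads[id(i)] += grads[id(n)]
    return grads

# A value used twice gets one backward step, with both contributions summed:
x = Node("x")
shared = Node("shared", [x])
out = Node("out", [shared, shared])
assert backward(out)[id(x)] == 2.0
```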
t-vi committed Jun 29, 2020
1 parent 3353b2d commit 0fbedf5
Showing 2 changed files with 44 additions and 18 deletions.
18 changes: 13 additions & 5 deletions src/relay/transforms/gradient.cc
@@ -185,14 +185,24 @@ struct ADFunction : ADValueNode {
};

struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
  using TBase = ExprFunctor<ADValue(const Expr&)>;
  const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
  std::vector<std::function<void(LetList* ll)>> backprop_actions;
  // we assume no closure so no need for lexical scoping
  std::unordered_map<Var, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
  std::unordered_map<Expr, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
  LetList* ll;

  FirstOrderReverseAD(LetList* ll) : ll(ll) {}

  ADValue VisitExpr(const Expr& n) final {
    if (env.count(n)) {
      return env.at(n);
    }
    auto ret = TBase::VisitExpr(n);
    env[n] = ret;
    return ret;
  }

  ADValue VisitExpr_(const OpNode* op) final {
    Op op_ref = GetRef<Op>(op);
    CHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined";
@@ -291,10 +301,8 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
    });
  }

  ADValue VisitExpr_(const VarNode* op) final {
    Var v = GetRef<Var>(op);
    return env.at(v);
  }
  // Var will always be in env, handled in VisitExpr (without _), so we don't need
  // to implement its VisitExpr_.
};

Type GradRetType(const Function& f) {
44 changes: 31 additions & 13 deletions tests/python/relay/test_pass_gradient.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import collections
import numpy as np
import pytest

import tvm
from tvm import te
@@ -27,6 +29,20 @@
import tvm.relay.op as op


def count_ops(expr):
    class OpCounter(tvm.relay.ExprVisitor):
        def visit_call(self, expr):
            if hasattr(expr, 'op'):
                self.node_counter[expr.op.name] += 1
            return super().visit_call(expr)
        def count(self, expr):
            self.node_set = {}
            self.node_counter = collections.Counter()
            self.visit(expr)
            return self.node_counter
    return OpCounter().count(expr)


def test_id():
shape = (10, 10)
dtype = 'float32'
@@ -309,17 +325,19 @@ def test_concat():
# no value validation as concatenate has dummy gradient right now.


def test_no_duplication():
    x = tvm.relay.Var('x', type_annotation=tvm.relay.TensorType([12, 12]))
    y = tvm.relay.Var('y', type_annotation=tvm.relay.TensorType([12, 12]))
    xy = tvm.relay.nn.dense(x, y)

    m = tvm.relay.sum(xy, keepdims=True)
    s = tvm.relay.sum(xy - m)
    fn = tvm.relay.Function([x, y], s)
    fn = run_infer_type(fn)
    gr = tvm.relay.transform.gradient(fn, mode='first_order')

    counts = count_ops(gr)
    assert counts['nn.dense'] == 3, "We expect 3 dense ops (one forward, two backward)"

if __name__ == "__main__":
    test_id()
    test_add()
    test_temp_add()
    test_sub()
    test_broadcast_add()
    test_broadcast_subtract()
    test_tuple()
    test_tuple_first_order()
    test_pow()
    test_ref()
    test_square_second_order()
    test_if()
    test_grad_tuple()
    pytest.main([__file__])
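To exercise just the new regression test against a TVM build that includes this patch, one possible invocation (assuming the repository layout shown in the diff above; the working directory is your TVM checkout) is:

```python
# Hypothetical invocation; run from the root of a TVM checkout with this patch.
import pytest
pytest.main(["tests/python/relay/test_pass_gradient.py::test_no_duplication", "-q"])
```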
