Skip to content

Commit

Permalink
[Relay] Fix ad for conditional expression (#3453)
Browse files Browse the repository at this point in the history
* save

* fix
  • Loading branch information
MarisaKirisame authored and vinx13 committed Jun 28, 2019
1 parent 8fe715f commit 329378c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/relay/pass/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -311,6 +311,12 @@ struct ReverseAD : ExprMutator {
return Pair(e, RefCreateNode::make(ZerosLike(e)));
}

Expr VisitExpr_(const IfNode* op) final {
return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
VisitExpr(op->true_branch),
VisitExpr(op->false_branch));
}

Type VisitType(const Type& t) final {
return t.defined() ? ReverseADType()(t) : t;
}
Expand Down
11 changes: 11 additions & 0 deletions tests/python/relay/test_pass_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,16 @@ def test_square_second_order():
tvm.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))


def test_if():
x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.var("y", shape=(1, 16, 64, 64))
cond = relay.var("cond", shape=(), dtype='uint1')
net = relay.If(cond, x, y)
net = relay.log(net)
net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))


if __name__ == "__main__":
test_id()
test_add()
Expand All @@ -242,3 +252,4 @@ def test_square_second_order():
test_pow()
test_ref()
test_square_second_order()
test_if()

0 comments on commit 329378c

Please sign in to comment.