Skip to content

Commit

Permalink
fix (apache#3550)
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon authored and wweic committed Sep 6, 2019
1 parent 6f59cd3 commit efdd1ca
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/relay/pass/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._analysis.first_order_gradient")
TVM_REGISTER_API("relay._transform.first_order_gradient")
.set_body_typed(FirstOrderGradient);

struct ReverseADType : TypeMutator {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_pass_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_id():
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x)
back_func = run_infer_type(gradient(func))
back_func = run_infer_type(gradient(func, mode="first_order"))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
Expand Down

0 comments on commit efdd1ca

Please sign in to comment.