From c82fbdb1d850038c26c335803a6827ea1e92f620 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 6 Feb 2019 20:30:21 -0800 Subject: [PATCH] address comment --- src/relay/pass/gradient.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index fad4347549ef1..1b549ba97e31c 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -85,10 +85,10 @@ using ADValue = std::shared_ptr; /*! \brief AD over a program which generates a tensor output. */ struct ADTensor : ADValueNode { - Expr foward; + Expr forward; mutable Expr reverse; // must be a variable to avoid duplication - ADTensor(LetList* ll, const Expr& foward) : - foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { } + ADTensor(LetList* ll, const Expr& forward) : + forward(ll->Push(forward)), reverse(ll->Push(ZeroLike(this->forward))) { } }; /*! \brief A staged representation of the program, we reflect @@ -123,7 +123,7 @@ struct FirstOrderReverseAD : ExprFunctor { const tvm::Array& type_args) { std::vector call_args; for (const ADValue& adval : args) { - call_args.push_back(adval->get().foward); + call_args.push_back(adval->get().forward); } auto orig = CallNode::make(op_ref, call_args, attrs, type_args); auto ret = std::make_shared(ll, orig); @@ -209,7 +209,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { auto c = rev->get().func(args, Attrs(), {}); const auto& res = c->get(); Expr grad = LetList::With([&](LetList* ll) { - res.reverse = OneLike(res.foward); + res.reverse = OneLike(res.forward); for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend(); ++it) { @@ -221,7 +221,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { } return TupleNode::make(grad_res); }); - return Pair(res.foward, grad); + return Pair(res.forward, grad); }); return FunctionNode::make(f->params, body, GradRetType(GetRef(f)), {}); @@ -247,7 +247,7 @@ struct ReverseAD : ExprMutator { ReverseAD(const Var& bp) : bp(bp) { } Expr VisitExpr_(const OpNode* op) final { - CHECK(false) << "op should only be inside call"; + LOG(FATAL) << "op should only be inside call"; throw; }