diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index eb60176c1225..2c21c751a0e7 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -75,15 +75,15 @@ class InputVisitor: public ExprFunctor { Expr wrapExpr(const Expr expr, const Type& type) { if (type.as()) { - return CallNode::make(module_->GetConstructor("GradCell", "Raw"), + return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t)); + fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); } - Expr tuple = TupleNode::make(fields); + Expr tuple = Tuple(fields); return tuple; } @@ -112,16 +112,16 @@ class OutputVisitor: public ExprFunctor { Expr unwrapExpr(const Expr expr, const Type& type) { if (auto* type_call = type.as()) { if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { - return CallNode::make(module_->GetGlobalVar("FromGradCell"), {expr}); + return Call(module_->GetGlobalVar("FromGradCell"), {expr}); } return expr; } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t)); + fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); } - Expr tuple = TupleNode::make(fields); + Expr tuple = Tuple(fields); return tuple; } @@ -166,7 +166,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type()); args.push_back(wrappedInput); } - Expr transformedExpr = CallNode::make(GetRef(transformed), args); + Expr transformedExpr = Call(GetRef(transformed), args); // unwrap outputs of GradCell type into Tensor type using OutputVisitor class Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type); @@ -174,7 +174,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } Expr VisitExpr_(const ConstantNode* op) final { - return CallNode::make(module_->GetConstructor("GradCell", "Raw"), + return Call(module_->GetConstructor("GradCell", "Raw"), {GetRef(op)}, Attrs(), {op->checked_type()}); } @@ -189,9 +189,9 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { tvm::Array args; // create add function Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {VarNode::make("lhs", paramType), - VarNode::make("rhs", paramType)}; - Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]}); + tvm::Array params = {Var("lhs", paramType), + Var("rhs", paramType)}; + Expr callAdd = Call(Op::Get("add"), {params[0], params[1]}); Expr addTensorsFunc = Function(params, callAdd, paramType, Array()); @@ -200,7 +200,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { for (Expr expr : call_node->args) { args.push_back(VisitExpr(expr)); } - return CallNode::make(addFunc, args, Attrs(), {paramType}); + return Call(addFunc, args, Attrs(), {paramType}); } else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { // case: "multiply" between two tensors of the same size @@ -208,9 +208,9 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { // create multiply function tvm::Array args; Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {VarNode::make("lhs", paramType), - VarNode::make("rhs", paramType)}; - Expr callMultiply = CallNode::make(Op::Get("multiply"), + tvm::Array params = {Var("lhs", paramType), + Var("rhs", paramType)}; + Expr callMultiply = Call(Op::Get("multiply"), {params[0], params[1]}); Expr multTensorsFunc = Function(params, callMultiply, paramType, Array()); @@ -220,18 +220,18 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { for (Expr expr : call_node->args) { args.push_back(VisitExpr(expr)); } - return CallNode::make(multFunc, args, Attrs(), {paramType}); + return Call(multFunc, args, Attrs(), {paramType}); } else if (op_expr == Op::Get("ones")) { // ones operator, use One constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(module_->GetConstructor("GradCell", "One"), + return Call(module_->GetConstructor("GradCell", "One"), {func}, Attrs(), {call_node->checked_type()}); } else if (op_expr == Op::Get("zeros")) { // zeros operator, use Zero constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(module_->GetConstructor("GradCell", "Zero"), + return Call(module_->GetConstructor("GradCell", "Zero"), {func}, Attrs(), {call_node->checked_type()}); } @@ -242,24 +242,24 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { tvm::Array args; // use FromGradCell to convert args to Tensor for (Expr expr : call_node->args) { - args.push_back(CallNode::make(fromFunc, - {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + args.push_back(Call(fromFunc, + {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } - const Expr tensorRes = CallNode::make(call_node->op, args); + const Expr tensorRes = Call(call_node->op, args); if (op_expr == Op::Get("ones_like")) { Expr onesFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(module_->GetConstructor("GradCell", "One"), + return Call(module_->GetConstructor("GradCell", "One"), {onesFunction}, Attrs(), {call_node->checked_type()}); } else if (op_expr == Op::Get("zeros_like")) { Expr zerosFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(module_->GetConstructor("GradCell", "Zero"), + return Call(module_->GetConstructor("GradCell", "Zero"), {zerosFunction}, Attrs(), {call_node->checked_type()}); } - return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {tensorRes}, + return Call(module_->GetConstructor("GradCell", "Raw"), {tensorRes}, Attrs(), {call_node->checked_type()}); } // call-> op is not a relay op