update node instantiations
hypercubestart committed Mar 23, 2020
1 parent 953791b commit 3e18cde
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions src/relay/transforms/gradient_cell.cc
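The commit is a mechanical migration: every legacy XNode::make(...) static factory call is replaced by the matching reference-class constructor X(...) with the same arguments, following TVM's object protocol refactor. A minimal before/after sketch of the pattern, assuming the Relay C++ API of this revision; the header paths and the helper functions below are illustrative, not part of the commit:

// Sketch only; assumes <tvm/relay/expr.h> and <tvm/relay/op.h> from the same TVM revision.
#include <string>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>

using namespace tvm;
using namespace tvm::relay;

Expr AddExprs(const Expr& lhs, const Expr& rhs) {
  // Before this commit:
  //   return CallNode::make(Op::Get("add"), {lhs, rhs});
  // After: the reference class itself acts as the constructor.
  return Call(Op::Get("add"), {lhs, rhs});
}

Var MakeParam(const std::string& name, const Type& t) {
  return Var(name, t);               // was VarNode::make(name, t)
}

Expr FirstOfPair(const Expr& a, const Expr& b) {
  Expr pair = Tuple({a, b});         // was TupleNode::make({a, b})
  return TupleGetItem(pair, 0);      // was TupleGetItemNode::make(pair, 0)
}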
@@ -75,15 +75,15 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {

   Expr wrapExpr(const Expr expr, const Type& type) {
     if (type.as<TensorTypeNode>()) {
-      return CallNode::make(module_->GetConstructor("GradCell", "Raw"),
+      return Call(module_->GetConstructor("GradCell", "Raw"),
                   {expr}, Attrs(), {type});
     } else if (auto* type_anno = type.as<TupleTypeNode>()) {
       tvm::Array<Expr> fields;
       for (size_t i = 0; i < type_anno->fields.size(); i++) {
         const Type& t = type_anno->fields[i];
-        fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t));
+        fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
       }
-      Expr tuple = TupleNode::make(fields);
+      Expr tuple = Tuple(fields);
       return tuple;
     }

@@ -112,16 +112,16 @@ class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
   Expr unwrapExpr(const Expr expr, const Type& type) {
     if (auto* type_call = type.as<TypeCallNode>()) {
       if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
-        return CallNode::make(module_->GetGlobalVar("FromGradCell"), {expr});
+        return Call(module_->GetGlobalVar("FromGradCell"), {expr});
       }
       return expr;
     } else if (auto* type_anno = type.as<TupleTypeNode>()) {
       tvm::Array<Expr> fields;
       for (size_t i = 0; i < type_anno->fields.size(); i++) {
         const Type& t = type_anno->fields[i];
-        fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t));
+        fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
       }
-      Expr tuple = TupleNode::make(fields);
+      Expr tuple = Tuple(fields);
       return tuple;
     }

@@ -166,15 +166,15 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
       Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type());
       args.push_back(wrappedInput);
     }
-    Expr transformedExpr = CallNode::make(GetRef<Function>(transformed), args);
+    Expr transformedExpr = Call(GetRef<Function>(transformed), args);

     // unwrap outputs of GradCell type into Tensor type using OutputVisitor class
     Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type);
     return Function(f->params, tensorOutput, f->ret_type, Array<TypeVar>());
   }

   Expr VisitExpr_(const ConstantNode* op) final {
-    return CallNode::make(module_->GetConstructor("GradCell", "Raw"),
+    return Call(module_->GetConstructor("GradCell", "Raw"),
                 {GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
   }
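The constant case above also shows the four-argument form used throughout the pass: the ADT constructor comes first, then the value arguments, an empty Attrs(), and finally the type arguments that instantiate the polymorphic GradCell datatype. A sketch of the same wrapping as a standalone helper, assuming the module handle is an IRModule as in the rest of the file; the name mod and the helper itself are illustrative:

// Sketch; mirrors the GradCell "Raw" wrapping done by InputVisitor and the ConstantNode case.
Expr WrapRaw(const IRModule& mod, const Expr& tensor) {
  // GradCell.Raw : T -> GradCell[T]; the tensor's checked type instantiates the type parameter.
  return Call(mod->GetConstructor("GradCell", "Raw"),
              {tensor}, Attrs(), {tensor->checked_type()});
}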

@@ -189,9 +189,9 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
       tvm::Array<Expr> args;
       // create add function
       Type paramType = call_node->args[0]->checked_type();
-      tvm::Array<Var> params = {VarNode::make("lhs", paramType),
-                                VarNode::make("rhs", paramType)};
-      Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]});
+      tvm::Array<Var> params = {Var("lhs", paramType),
+                                Var("rhs", paramType)};
+      Expr callAdd = Call(Op::Get("add"), {params[0], params[1]});
       Expr addTensorsFunc = Function(params, callAdd, paramType,
                                      Array<TypeVar>());

@@ -200,17 +200,17 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
       for (Expr expr : call_node->args) {
         args.push_back(VisitExpr(expr));
       }
-      return CallNode::make(addFunc, args, Attrs(), {paramType});
+      return Call(addFunc, args, Attrs(), {paramType});
     } else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 &&
                AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
       // case: "multiply" between two tensors of the same size
       const auto multFunc = module_->GetGlobalVar("MultiplyGradCell");
       // create multiply function
       tvm::Array<Expr> args;
       Type paramType = call_node->args[0]->checked_type();
-      tvm::Array<Var> params = {VarNode::make("lhs", paramType),
-                                VarNode::make("rhs", paramType)};
-      Expr callMultiply = CallNode::make(Op::Get("multiply"),
+      tvm::Array<Var> params = {Var("lhs", paramType),
+                                Var("rhs", paramType)};
+      Expr callMultiply = Call(Op::Get("multiply"),
                                {params[0], params[1]});
       Expr multTensorsFunc = Function(params, callMultiply, paramType,
                                       Array<TypeVar>());
@@ -220,18 +220,18 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
       for (Expr expr : call_node->args) {
         args.push_back(VisitExpr(expr));
       }
-      return CallNode::make(multFunc, args, Attrs(), {paramType});
+      return Call(multFunc, args, Attrs(), {paramType});
     } else if (op_expr == Op::Get("ones")) {
       // ones operator, use One constructor of GradCell
       Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
                            {call_node->checked_type()}, {});
-      return CallNode::make(module_->GetConstructor("GradCell", "One"),
+      return Call(module_->GetConstructor("GradCell", "One"),
                   {func}, Attrs(), {call_node->checked_type()});
     } else if (op_expr == Op::Get("zeros")) {
       // zeros operator, use Zero constructor of GradCell
       Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
                            {call_node->checked_type()}, {});
-      return CallNode::make(module_->GetConstructor("GradCell", "Zero"),
+      return Call(module_->GetConstructor("GradCell", "Zero"),
                   {func}, Attrs(), {call_node->checked_type()});
     }
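In the ones and zeros branches above, only the final instantiation changes; the unchanged context lines still wrap the original call in a zero-parameter Function, so the One and Zero constructors receive a thunk that is forced later by FromGradCell rather than an eagerly built tensor. A sketch of the resulting shape, with mod, call_node, and the helper name assumed rather than taken from the commit:

// Sketch of the post-change "ones" lowering; names (mod, ones_call, MakeOneCell) are assumptions.
Expr MakeOneCell(const IRModule& mod, const CallNode* call_node, const Expr& ones_call) {
  // Delay evaluation of ones(...) behind a zero-parameter function (a thunk).
  Expr thunk = Function({}, ones_call, call_node->checked_type(), {});
  // GradCell.One : (fn() -> T) -> GradCell[T]
  return Call(mod->GetConstructor("GradCell", "One"),
              {thunk}, Attrs(), {call_node->checked_type()});
}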

@@ -242,24 +242,24 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
       tvm::Array<Expr> args;
       // use FromGradCell to convert args to Tensor
       for (Expr expr : call_node->args) {
-        args.push_back(CallNode::make(fromFunc,
-                       {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
+        args.push_back(Call(fromFunc,
+                       {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
       }

-      const Expr tensorRes = CallNode::make(call_node->op, args);
+      const Expr tensorRes = Call(call_node->op, args);

       if (op_expr == Op::Get("ones_like")) {
         Expr onesFunction = Function({}, tensorRes,
                                      {call_node->checked_type()}, Array<TypeVar>());
-        return CallNode::make(module_->GetConstructor("GradCell", "One"),
+        return Call(module_->GetConstructor("GradCell", "One"),
                    {onesFunction}, Attrs(), {call_node->checked_type()});
       } else if (op_expr == Op::Get("zeros_like")) {
         Expr zerosFunction = Function({}, tensorRes,
                                       {call_node->checked_type()}, Array<TypeVar>());
-        return CallNode::make(module_->GetConstructor("GradCell", "Zero"),
+        return Call(module_->GetConstructor("GradCell", "Zero"),
                    {zerosFunction}, Attrs(), {call_node->checked_type()});
       }
-      return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {tensorRes},
+      return Call(module_->GetConstructor("GradCell", "Raw"), {tensorRes},
                  Attrs(), {call_node->checked_type()});
     }
     // call-> op is not a relay op
