Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Fix format #1957

Merged
merged 2 commits into from
Oct 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/relay/ir/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ SourceName SourceName::Get(const std::string& name) {
}

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) {
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
p->stream << "SourceName(" << node->name << ", " << node << ")";
});

Expand All @@ -54,12 +54,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode);

TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SpanNode::make(args[0], args[1], args[2]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode *node, tvm::IRPrinter *p) {
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
});
Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ Function EnvironmentNode::Lookup(const GlobalVar& var) {
return (*it).second;
}

Function EnvironmentNode::Lookup(const std::string &name) {
Function EnvironmentNode::Lookup(const std::string& name) {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}

void EnvironmentNode::Update(const Environment &env) {
void EnvironmentNode::Update(const Environment& env) {
for (auto pair : env->functions) {
this->Update(pair.first, pair.second);
}
Expand Down
34 changes: 17 additions & 17 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE(ConstantNode);

TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstantNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode *node, tvm::IRPrinter *p) {
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
p->stream << "Constant(TODO)";
});

Expand All @@ -49,12 +49,12 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE(TupleNode);

TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode *node, tvm::IRPrinter *p) {
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
p->stream << "Tuple(" << node->fields << ")";
});

Expand All @@ -68,12 +68,12 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE(VarNode);

TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = VarNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
p->stream << "Var(" << node->name_hint;
if (node->type_annotation.defined()) {
p->stream << ", ty=";
Expand All @@ -91,12 +91,12 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode);

TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GlobalVarNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode *node, tvm::IRPrinter *p) {
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
p->stream << "GlobalVar(" << node->name_hint << ")";
});

Expand Down Expand Up @@ -124,13 +124,13 @@ FuncType FunctionNode::func_type_annotation() const {
TVM_REGISTER_NODE_TYPE(FunctionNode);

TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode *node,
tvm::IRPrinter *p) {
.set_dispatch<FunctionNode>([](const FunctionNode* node,
tvm::IRPrinter* p) {
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ")";
});
Expand All @@ -148,12 +148,12 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE(CallNode);

TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CallNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode *node, tvm::IRPrinter *p) {
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
});
Expand All @@ -169,12 +169,12 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE(LetNode);

TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LetNode::make(args[0], args[1], args[2]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) {
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ")";
});
Expand All @@ -189,12 +189,12 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {

TVM_REGISTER_NODE_TYPE(IfNode);

TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IfNode::make(args[0], args[1], args[2]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) {
.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< ", " << node->false_branch << ")";
});
Expand Down
28 changes: 14 additions & 14 deletions src/relay/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
TVM_REGISTER_NODE_TYPE(TensorTypeNode);

TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
Array<IndexExpr> shape = args[0];
*ret = TensorTypeNode::make(shape, args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});

Expand All @@ -46,15 +46,15 @@ TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);

TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1];
*ret =
TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const TypeVarNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TypeVarNode>([](const TypeVarNode* node,
tvm::IRPrinter* p) {
p->stream << "TypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")";
});
Expand Down Expand Up @@ -95,13 +95,13 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE(FuncTypeNode);

TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "FuncTypeNode(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
<< node->type_constraints << ")";
Expand All @@ -122,12 +122,12 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE(TypeRelationNode);

TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode *node, tvm::IRPrinter *p) {
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
p->stream << "TypeRelationNode("
<< node->func->name
<< ", " << node->args << ")";
Expand All @@ -142,13 +142,13 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
TVM_REGISTER_NODE_TYPE(TupleTypeNode);

TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleTypeNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode *node,
tvm::IRPrinter *p) {
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "TupleTypeNode(" << node->fields << ")";
});

Expand Down
11 changes: 8 additions & 3 deletions src/relay/pass/alpha_eq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,13 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
};

bool AlphaEqual(const Type& t1, const Type& t2) {
if (t1.defined() != t2.defined())
if (t1.defined() != t2.defined()) {
return false;
}

if (!t1.defined())
if (!t1.defined()) {
return true;
}

TypeAlphaEq aeq;
aeq.VisitType(t1, t2);
Expand Down Expand Up @@ -273,7 +275,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
for (size_t i = 0; i < func1->params.size(); ++i) {
MergeVarDecl(func1->params[i], func2->params[i]);
}
if (!equal) return;

if (!equal) {
return;
}

for (size_t i = 0U; i < func1->type_params.size(); i++) {
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
Expand Down
6 changes: 3 additions & 3 deletions src/relay/pass/kind_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ struct KindChecker : TypeVisitor<> {

// checks if t is an incomplete node of kind k or a type param of kind k
bool MatchKind(const Type& t, Kind k) {
if (const IncompleteTypeNode *tv = t.as<IncompleteTypeNode>()) {
if (const IncompleteTypeNode* tv = t.as<IncompleteTypeNode>()) {
return tv->kind == k;
}

if (const TypeVarNode *tp = t.as<TypeVarNode>()) {
if (const TypeVarNode* tp = t.as<TypeVarNode>()) {
return tp->kind == k;
}

Expand Down Expand Up @@ -93,7 +93,7 @@ struct KindChecker : TypeVisitor<> {
}
}

bool Check(const Type &t) {
bool Check(const Type& t) {
this->VisitType(t);
return valid;
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ class TypeInferencer::Resolver : public ExprMutator {
return new_e;
}

Type VisitType(const Type &t) final {
Type VisitType(const Type& t) final {
return solver_->Resolve(t);
}

Expand Down
14 changes: 7 additions & 7 deletions src/relay/pass/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ namespace relay {

class FreeVar;
class FreeTypeVar : private TypeVisitor<> {
std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars;
FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual> * free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual> * bound_vars) :
std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars;
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars;
FreeTypeVar(std::unordered_set<TypeVar, NodeHash, NodeEqual>* free_vars,
std::unordered_set<TypeVar, NodeHash, NodeEqual>* bound_vars) :
free_vars(free_vars), bound_vars(bound_vars) { }

void VisitType_(const TypeVarNode* tp) final {
Expand Down Expand Up @@ -45,7 +45,7 @@ class FreeTypeVar : private TypeVisitor<> {
};

class FreeVar : public ExprVisitor {
void VisitExpr_(const VarNode *v) final {
void VisitExpr_(const VarNode* v) final {
auto var = GetRef<Var>(v);
if (bound_vars.count(var) == 0) {
free_vars.insert(var);
Expand All @@ -55,7 +55,7 @@ class FreeVar : public ExprVisitor {
}
}

void VisitExpr_(const FunctionNode *f) final {
void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) {
bound_types.insert(tp);
}
Expand All @@ -66,7 +66,7 @@ class FreeVar : public ExprVisitor {
VisitType(f->ret_type);
}

void VisitExpr_(const LetNode *l) final {
void VisitExpr_(const LetNode* l) final {
bound_vars.insert(l->var);
VisitExpr(l->value);
VisitExpr(l->body);
Expand Down
12 changes: 6 additions & 6 deletions src/relay/pass/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,36 @@ class WellFormedChecker : private ExprVisitor {

std::unordered_set<Var, NodeHash, NodeEqual> s;

void Check(const Var & v) {
void Check(const Var& v) {
if (s.count(v) != 0) {
well_formed = false;
}
s.insert(v);
}

void VisitExpr_(const LetNode * l) final {
void VisitExpr_(const LetNode* l) final {
// we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var);
CheckWellFormed(l->value);
CheckWellFormed(l->body);
}

void VisitExpr_(const FunctionNode * f) final {
for (const Var & param : f->params) {
void VisitExpr_(const FunctionNode* f) final {
for (const Var& param : f->params) {
Check(param);
}
CheckWellFormed(f->body);
}

public:
bool CheckWellFormed(const Expr & e) {
bool CheckWellFormed(const Expr& e) {
this->VisitExpr(e);
return well_formed;
}
};

bool WellFormed(const Expr & e) {
bool WellFormed(const Expr& e) {
return WellFormedChecker().CheckWellFormed(e);
}

Expand Down