From 09efdc9d6c92f2817c3d4a2d8cd29dd374da3704 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Mon, 22 Oct 2018 22:33:37 -0700 Subject: [PATCH] [Relay] Fix format (#1957) * save * fix format --- src/relay/ir/base.cc | 6 +++--- src/relay/ir/environment.cc | 4 ++-- src/relay/ir/expr.cc | 34 +++++++++++++++++----------------- src/relay/ir/type.cc | 28 ++++++++++++++-------------- src/relay/pass/alpha_eq.cc | 11 ++++++++--- src/relay/pass/kind_check.cc | 6 +++--- src/relay/pass/type_infer.cc | 2 +- src/relay/pass/util.cc | 14 +++++++------- src/relay/pass/well_formed.cc | 12 ++++++------ 9 files changed, 61 insertions(+), 56 deletions(-) diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 4e71444bf1ae..a68910e56b71 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -33,7 +33,7 @@ SourceName SourceName::Get(const std::string& name) { } TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const SourceNameNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const SourceNameNode* node, tvm::IRPrinter* p) { p->stream << "SourceName(" << node->name << ", " << node << ")"; }); @@ -54,12 +54,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { TVM_REGISTER_NODE_TYPE(SpanNode); TVM_REGISTER_API("relay._make.Span") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = SpanNode::make(args[0], args[1], args[2]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const SpanNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const SpanNode* node, tvm::IRPrinter* p) { p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", " << node->col_offset << ")"; }); diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 6dfaa0b24a53..dddad82c8afc 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -73,12 +73,12 @@ Function EnvironmentNode::Lookup(const GlobalVar& var) { return (*it).second; } -Function EnvironmentNode::Lookup(const std::string &name) { +Function EnvironmentNode::Lookup(const std::string& name) { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } -void EnvironmentNode::Update(const Environment &env) { +void EnvironmentNode::Update(const Environment& env) { for (auto pair : env->functions) { this->Update(pair.first, pair.second); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 2d373b769559..c75c414c8ce9 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -20,12 +20,12 @@ Constant ConstantNode::make(runtime::NDArray data) { TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_API("relay._make.Constant") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = ConstantNode::make(args[0]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const ConstantNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const ConstantNode* node, tvm::IRPrinter* p) { p->stream << "Constant(TODO)"; }); @@ -49,12 +49,12 @@ Tuple TupleNode::make(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_API("relay._make.Tuple") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = TupleNode::make(args[0]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const TupleNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const TupleNode* node, tvm::IRPrinter* p) { p->stream << "Tuple(" << node->fields << ")"; }); @@ -68,12 +68,12 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_API("relay._make.Var") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = VarNode::make(args[0], args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const VarNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const VarNode* node, tvm::IRPrinter* p) { p->stream << "Var(" << node->name_hint; if (node->type_annotation.defined()) { p->stream << ", ty="; @@ -91,12 +91,12 @@ GlobalVar GlobalVarNode::make(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); TVM_REGISTER_API("relay._make.GlobalVar") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = GlobalVarNode::make(args[0]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const GlobalVarNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const GlobalVarNode* node, tvm::IRPrinter* p) { p->stream << "GlobalVar(" << node->name_hint << ")"; }); @@ -124,13 +124,13 @@ FuncType FunctionNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_API("relay._make.Function") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const FunctionNode *node, - tvm::IRPrinter *p) { +.set_dispatch([](const FunctionNode* node, + tvm::IRPrinter* p) { p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body << ", " << node->type_params << ")"; }); @@ -148,12 +148,12 @@ Call CallNode::make(Expr op, Array args, Attrs attrs, TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_API("relay._make.Call") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = CallNode::make(args[0], args[1], args[2], args[3]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const CallNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const CallNode* node, tvm::IRPrinter* p) { p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " << node->type_args << ")"; }); @@ -169,12 +169,12 @@ Let LetNode::make(Var var, Expr value, Expr body) { TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_API("relay._make.Let") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = LetNode::make(args[0], args[1], args[2]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const LetNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const LetNode* node, tvm::IRPrinter* p) { p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; }); @@ -189,12 +189,12 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { TVM_REGISTER_NODE_TYPE(IfNode); -TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = IfNode::make(args[0], args[1], args[2]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const IfNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const IfNode* node, tvm::IRPrinter* p) { p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " << node->false_branch << ")"; }); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 39347adced92..d6fc2e85b2d8 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -25,14 +25,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_API("relay._make.TensorType") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { Array shape = args[0]; *ret = TensorTypeNode::make(shape, args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const TensorTypeNode *node, - tvm::IRPrinter *p) { +.set_dispatch([](const TensorTypeNode* node, + tvm::IRPrinter* p) { p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); @@ -46,15 +46,15 @@ TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_API("relay._make.TypeVar") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[1]; *ret = TypeVarNode::make(args[0], static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const TypeVarNode *node, - tvm::IRPrinter *p) { +.set_dispatch([](const TypeVarNode* node, + tvm::IRPrinter* p) { p->stream << "TypeVarNode(" << node->var->name_hint << ", " << node->kind << ")"; }); @@ -95,13 +95,13 @@ FuncType FuncTypeNode::make(tvm::Array arg_types, TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_API("relay._make.FuncType") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const FuncTypeNode *node, - tvm::IRPrinter *p) { +.set_dispatch([](const FuncTypeNode* node, + tvm::IRPrinter* p) { p->stream << "FuncTypeNode(" << node->type_params << ", " << node->arg_types << ", " << node->ret_type << ", " << node->type_constraints << ")"; @@ -122,12 +122,12 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func, TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_API("relay._make.TypeRelation") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const TypeRelationNode *node, tvm::IRPrinter *p) { +.set_dispatch([](const TypeRelationNode* node, tvm::IRPrinter* p) { p->stream << "TypeRelationNode(" << node->func->name << ", " << node->args << ")"; @@ -142,13 +142,13 @@ TupleType TupleTypeNode::make(Array fields) { TVM_REGISTER_NODE_TYPE(TupleTypeNode); TVM_REGISTER_API("relay._make.TupleType") -.set_body([](TVMArgs args, TVMRetValue *ret) { +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = TupleTypeNode::make(args[0]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const TupleTypeNode *node, - tvm::IRPrinter *p) { +.set_dispatch([](const TupleTypeNode* node, + tvm::IRPrinter* p) { p->stream << "TupleTypeNode(" << node->fields << ")"; }); diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 56aeefda78f1..41ec3f1e090b 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -193,11 +193,13 @@ struct TypeAlphaEq : TypeVisitor { }; bool AlphaEqual(const Type& t1, const Type& t2) { - if (t1.defined() != t2.defined()) + if (t1.defined() != t2.defined()) { return false; + } - if (!t1.defined()) + if (!t1.defined()) { return true; + } TypeAlphaEq aeq; aeq.VisitType(t1, t2); @@ -273,7 +275,10 @@ struct AlphaEq : ExprFunctor { for (size_t i = 0; i < func1->params.size(); ++i) { MergeVarDecl(func1->params[i], func2->params[i]); } - if (!equal) return; + + if (!equal) { + return; + } for (size_t i = 0U; i < func1->type_params.size(); i++) { equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 3f4d81b7e24f..8fd77a71ec4b 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -29,11 +29,11 @@ struct KindChecker : TypeVisitor<> { // checks if t is an incomplete node of kind k or a type param of kind k bool MatchKind(const Type& t, Kind k) { - if (const IncompleteTypeNode *tv = t.as()) { + if (const IncompleteTypeNode* tv = t.as()) { return tv->kind == k; } - if (const TypeVarNode *tp = t.as()) { + if (const TypeVarNode* tp = t.as()) { return tp->kind == k; } @@ -93,7 +93,7 @@ struct KindChecker : TypeVisitor<> { } } - bool Check(const Type &t) { + bool Check(const Type& t) { this->VisitType(t); return valid; } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 3e233274af2e..0cbce833aed9 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -379,7 +379,7 @@ class TypeInferencer::Resolver : public ExprMutator { return new_e; } - Type VisitType(const Type &t) final { + Type VisitType(const Type& t) final { return solver_->Resolve(t); } diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 8ebac921203f..ff4bb55b7b3c 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -14,10 +14,10 @@ namespace relay { class FreeVar; class FreeTypeVar : private TypeVisitor<> { - std::unordered_set * free_vars; - std::unordered_set * bound_vars; - FreeTypeVar(std::unordered_set * free_vars, - std::unordered_set * bound_vars) : + std::unordered_set* free_vars; + std::unordered_set* bound_vars; + FreeTypeVar(std::unordered_set* free_vars, + std::unordered_set* bound_vars) : free_vars(free_vars), bound_vars(bound_vars) { } void VisitType_(const TypeVarNode* tp) final { @@ -45,7 +45,7 @@ class FreeTypeVar : private TypeVisitor<> { }; class FreeVar : public ExprVisitor { - void VisitExpr_(const VarNode *v) final { + void VisitExpr_(const VarNode* v) final { auto var = GetRef(v); if (bound_vars.count(var) == 0) { free_vars.insert(var); @@ -55,7 +55,7 @@ class FreeVar : public ExprVisitor { } } - void VisitExpr_(const FunctionNode *f) final { + void VisitExpr_(const FunctionNode* f) final { for (const auto& tp : f->type_params) { bound_types.insert(tp); } @@ -66,7 +66,7 @@ class FreeVar : public ExprVisitor { VisitType(f->ret_type); } - void VisitExpr_(const LetNode *l) final { + void VisitExpr_(const LetNode* l) final { bound_vars.insert(l->var); VisitExpr(l->value); VisitExpr(l->body); diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index e008a72e5d90..a37969f9e317 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -18,14 +18,14 @@ class WellFormedChecker : private ExprVisitor { std::unordered_set s; - void Check(const Var & v) { + void Check(const Var& v) { if (s.count(v) != 0) { well_formed = false; } s.insert(v); } - void VisitExpr_(const LetNode * l) final { + void VisitExpr_(const LetNode* l) final { // we do letrec only for FunctionNode, // but shadowing let in let binding is likely programming error, and we should forbidden it. Check(l->var); @@ -33,21 +33,21 @@ class WellFormedChecker : private ExprVisitor { CheckWellFormed(l->body); } - void VisitExpr_(const FunctionNode * f) final { - for (const Var & param : f->params) { + void VisitExpr_(const FunctionNode* f) final { + for (const Var& param : f->params) { Check(param); } CheckWellFormed(f->body); } public: - bool CheckWellFormed(const Expr & e) { + bool CheckWellFormed(const Expr& e) { this->VisitExpr(e); return well_formed; } }; -bool WellFormed(const Expr & e) { +bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); }